Spaces:
Paused
Paused
File size: 5,107 Bytes
79ae05b 9155a62 79ae05b 9635653 9155a62 15b6036 780df80 9155a62 2d96b3b 514da67 9155a62 15b6036 2d96b3b 15b6036 79ae05b 9155a62 9635653 501bdbe 9635653 9155a62 780df80 63a0765 27a07a9 63a0765 9155a62 15b6036 9635653 79ae05b 9635653 79ae05b 9635653 79ae05b 9635653 2d96b3b 15b6036 9635653 15b6036 9635653 2d96b3b 15b6036 9635653 15b6036 9635653 2d96b3b 15b6036 2d96b3b 15b6036 2d96b3b 15b6036 2d96b3b 63a0765 4e3ab6e 88139f0 15b6036 4e3ab6e 2d96b3b 73fba58 4e3ab6e 79ae05b 780df80 2d96b3b 780df80 2d96b3b 79ae05b 15b6036 63a0765 79ae05b 15b6036 9155a62 2d96b3b 15b6036 9155a62 ba93af8 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 |
import asyncio
import json
import logging
import os
import time
from typing import Dict
import socketio
from dotenv import load_dotenv
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from knet import KNet
from scraper import CrawlForAIScraper
# Module-level setup: environment, logging, the FastAPI app with CORS, and the
# Socket.IO server that is mounted on it.
# Runs before the os.getenv call below; presumably so values from a local .env
# file are visible to it (python-dotenv) -- TODO confirm.
load_dotenv()
# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
app = FastAPI()
# Allowed origins come from the ALLOWED_ORIGINS environment variable as a
# comma-separated string and are shared by the HTTP and Socket.IO layers.
# NOTE(review): when the variable is unset, the "," fallback splits into
# ["", ""] (two empty strings), and entries are not stripped of whitespace --
# confirm this is the intended default.
CORS_ALLOWED_ORIGINS = os.getenv("ALLOWED_ORIGINS", ",").split(",")
app.add_middleware(
CORSMiddleware,
allow_origins=CORS_ALLOWED_ORIGINS,
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
# Socket.IO server in ASGI mode. ping_timeout / ping_interval are passed
# straight to socketio.AsyncServer (python-socketio documents them in seconds;
# the large timeout presumably tolerates long-running research -- TODO confirm).
sio = socketio.AsyncServer(
cors_allowed_origins=CORS_ALLOWED_ORIGINS,
ping_timeout=1200,
ping_interval=10,
async_mode="asgi",
)
# Mount the Socket.IO ASGI application at the root path of the FastAPI app.
app.mount("/", socketio.ASGIApp(sio))
class SessionManager:
    """Track per-client research sessions and their running tasks.

    Every session id (``sid``) maps to a ``(KNet, CrawlForAIScraper)`` pair in
    ``sessions`` and, once a task has been registered for it, to an
    ``asyncio.Task`` in ``tasks``.
    """

    def __init__(self):
        # sid -> (knet, scraper), created lazily by get_or_create_session
        self.sessions: Dict[str, "tuple[KNet, CrawlForAIScraper]"] = {}
        # sid -> most recently registered research task for that session
        self.tasks: Dict[str, asyncio.Task] = {}

    async def get_or_create_session(self, sid: str) -> "tuple[KNet, CrawlForAIScraper]":
        """Return the ``(knet, scraper)`` pair for *sid*, creating it if needed.

        A new scraper is started before the KNet instance is built on top of
        it; the pair is stored only after the scraper started successfully.
        """
        if sid not in self.sessions:
            scraper = CrawlForAIScraper()
            await scraper.start()
            knet = KNet(scraper)
            self.sessions[sid] = (knet, scraper)
        return self.sessions[sid]

    async def cleanup_session(self, sid: str):
        """Cancel any task registered for *sid* and release its session.

        Safe to call for an unknown *sid* and safe to call more than once.
        Exceptions raised by the cancelled task are logged, not propagated;
        an exception from ``scraper.close()`` does propagate.
        """
        # Pop up front so a finished task is not left behind in the registry
        # and so a task registered while we await the cancellation below is
        # not removed by mistake.
        task = self.tasks.pop(sid, None)
        if task is not None and not task.done():
            task.cancel()
            try:
                await task
            except asyncio.CancelledError:
                logger.info(f"Research task for session {sid} was cancelled")
            except Exception as e:
                logger.error(f"Error while cancelling task for {sid}: {str(e)}")

        # Drop the session entry before closing the scraper so a failing
        # close() cannot leave a stale entry behind.
        session = self.sessions.pop(sid, None)
        if session is not None:
            _, scraper = session
            await scraper.close()

    def register_task(self, sid: str, task: asyncio.Task):
        """Remember *task* as the running task for *sid* (replaces any previous one)."""
        self.tasks[sid] = task


# Single shared manager used by all Socket.IO event handlers in this module.
session_manager = SessionManager()
@sio.event
async def connect(sid, environ, auth):
    """Handle a new client connection.

    Logs the connection and makes sure a session exists for ``sid`` by asking
    the session manager to fetch or create one. ``environ`` and ``auth`` are
    accepted but not used.
    """
    logger.info(f"Client connected: {sid}")
    # The returned (knet, scraper) pair is not needed here.
    await session_manager.get_or_create_session(sid)
@sio.event
async def disconnect(sid, reason):
    """Handle a client disconnect.

    Logs the event and hands ``sid`` to the session manager for cleanup.
    ``reason`` is accepted but not used.
    """
    logger.info(f"Client disconnected: {sid}")
    await session_manager.cleanup_session(sid)
@sio.event
async def health_check(sid, data):
    """Reply to a client health probe.

    Emits a ``health_check`` event with ``{"status": "ok"}`` back to the
    requesting client only. ``data`` is ignored.
    """
    logger.debug("Health check received")
    reply = {"status": "ok"}
    await sio.emit("health_check", reply, room=sid)
@sio.event
async def start_research(sid, data):
    """Run a research job for the requesting client and report the outcome.

    Args:
        sid: Session id of the requesting client; all events are emitted to
            this id only.
        data: Request payload, either a dict or a JSON string. ``topic`` is
            required; ``max_depth`` and ``num_sites_per_query`` are forwarded
            to ``knet.conduct_research`` as-is (``None`` when absent).

    Emits ``status`` while the job runs (through the progress callback) and
    then exactly one of ``research_complete``, ``research_aborted`` or
    ``error``.
    """
    try:
        payload = data if isinstance(data, dict) else json.loads(data)
        topic = payload.get("topic")
        if not isinstance(topic, str):
            # Fail with a readable message instead of an AttributeError on None.
            raise ValueError("'topic' is required and must be a string")
        topic = topic.strip()
        max_depth = payload.get("max_depth")
        num_sites_per_query = payload.get("num_sites_per_query")

        knet, _ = await session_manager.get_or_create_session(sid)
        logger.info(f"Starting research for client {sid}.\nTopic '{topic}'")

        async def progress_callback(status: dict):
            await sio.emit("status", status, room=sid)

        # Run as a registered task so the session manager can cancel it.
        task = asyncio.create_task(
            knet.conduct_research(topic, progress_callback, max_depth, num_sites_per_query)
        )
        session_manager.register_task(sid, task)
        research_results = await task

        if not research_results:
            # emit() is a coroutine: it must be awaited or nothing is sent.
            await sio.emit("research_aborted", room=sid)
            return

        logger.info(f"Research completed for topic: {topic}")
        await sio.emit("research_complete", research_results, room=sid)
    except asyncio.CancelledError:
        # The task was cancelled (e.g. through the session manager's cleanup);
        # report it the same way the `test` handler does.
        logger.info(f"Research task for client {sid} was cancelled")
        await sio.emit("research_aborted", room=sid)
    except Exception as e:
        logger.error(f"Research error: {str(e)}")
        # Uses `sid`, which is always bound, so failures while parsing the
        # payload are still reported to the client.
        await sio.emit("error", {"message": str(e)}, room=sid)
@sio.event
async def abort_research(sid):
    """Abort the client's research on request.

    Delegates to the session manager's cleanup for ``sid``.
    """
    logger.info(f"Aborting research for client {sid}")
    await session_manager.cleanup_session(sid)
@sio.event
async def test(sid, data):
    """Run ``knet.test`` for the client and reply with the saved output file.

    Args:
        sid: Session id of the requesting client.
        data: Request payload, either a dict or a JSON string, with a
            ``topic`` string (newlines are removed from it).

    After the task finishes, the contents of ``output.log.json`` are sent as
    ``research_complete``. A cancelled task yields ``research_aborted``; any
    other failure yields ``error``.
    """
    payload = data if isinstance(data, dict) else json.loads(data)
    topic = payload.get("topic").strip().replace("\n", "")
    logger.info(json.dumps(payload, indent=2))
    knet, _ = await session_manager.get_or_create_session(sid)
    # Keep the one-second delay, but yield to the event loop while waiting:
    # time.sleep() here would stall every other connected client.
    await asyncio.sleep(1)

    async def progress_callback(status: dict):
        await sio.emit("status", status, room=sid)

    # Create a task and register it for proper cancellation
    task = asyncio.create_task(knet.test(topic, progress_callback))
    session_manager.register_task(sid, task)
    try:
        await task
        with open("output.log.json", "r") as f:
            results = json.load(f)
        await sio.emit("research_complete", results, room=sid)
    except asyncio.CancelledError:
        logger.info(f"Test task for '{topic}' was cancelled")
        await sio.emit("research_aborted", room=sid)
    except Exception as e:
        logger.error(f"Test error: {str(e)}")
        await sio.emit("error", {"message": str(e)}, room=sid)
if __name__ == "__main__":
    # Script entry point: serve the FastAPI app (with the Socket.IO app
    # mounted on it) on all interfaces, port 5000.
    logger.info("Starting KnowledgeNet server...")
    # Imported here so uvicorn is only needed when run as a script.
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=5000)
|