File size: 5,107 Bytes
79ae05b
9155a62
 
79ae05b
 
9635653
9155a62
 
 
15b6036
780df80
9155a62
2d96b3b
514da67
9155a62
15b6036
2d96b3b
 
 
 
 
15b6036
79ae05b
9155a62
9635653
501bdbe
9635653
 
 
9155a62
780df80
63a0765
 
27a07a9
63a0765
 
 
9155a62
15b6036
9635653
 
 
 
79ae05b
9635653
 
 
 
 
 
 
 
 
 
79ae05b
 
 
 
 
 
 
 
 
 
 
 
 
9635653
 
 
 
 
79ae05b
 
 
9635653
 
2d96b3b
 
15b6036
9635653
15b6036
9635653
2d96b3b
 
15b6036
9635653
15b6036
9635653
2d96b3b
 
15b6036
 
2d96b3b
15b6036
2d96b3b
 
15b6036
 
2d96b3b
63a0765
4e3ab6e
88139f0
 
 
 
 
15b6036
4e3ab6e
2d96b3b
73fba58
 
4e3ab6e
79ae05b
 
 
 
 
 
 
780df80
 
2d96b3b
 
780df80
 
2d96b3b
 
79ae05b
 
 
 
 
 
15b6036
 
63a0765
79ae05b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
15b6036
9155a62
2d96b3b
 
15b6036
9155a62
ba93af8
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
import asyncio
import json
import logging
import os
import time
from typing import Dict

import socketio
from dotenv import load_dotenv
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware

from knet import KNet
from scraper import CrawlForAIScraper

# Load variables from a .env file (python-dotenv) before the os.getenv()
# lookup below reads the environment.
load_dotenv()

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

app = FastAPI()
# Comma-separated origin list, shared by the HTTP CORS middleware and the
# Socket.IO server so both apply the same policy.
# NOTE(review): when ALLOWED_ORIGINS is unset, the "," default splits into
# ["", ""] (two empty-string origins) — confirm that is the intended default.
# Entries are not stripped, so "a, b" keeps the leading space on " b".
CORS_ALLOWED_ORIGINS = os.getenv("ALLOWED_ORIGINS", ",").split(",")
app.add_middleware(
    CORSMiddleware,
    allow_origins=CORS_ALLOWED_ORIGINS,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Socket.IO server in ASGI mode.
# NOTE(review): ping_timeout is far larger than ping_interval — presumably to
# keep clients connected during long-running research; confirm the units and
# intent against the python-socketio documentation.
sio = socketio.AsyncServer(
    cors_allowed_origins=CORS_ALLOWED_ORIGINS,
    ping_timeout=1200,
    ping_interval=10,
    async_mode="asgi",
)
# Mount the Socket.IO ASGI application at the root path of the FastAPI app.
app.mount("/", socketio.ASGIApp(sio))


class SessionManager:
    """Own the per-client research resources, keyed by Socket.IO session id.

    Each session is a ``(KNet, CrawlForAIScraper)`` pair. One research task
    per session id is tracked so it can be cancelled when the client aborts
    or disconnects.
    """

    def __init__(self):
        # sid -> (knet, scraper) pair created for that client.
        self.sessions: Dict[str, tuple[KNet, CrawlForAIScraper]] = {}
        # sid -> most recently registered research task for that client.
        self.tasks: Dict[str, asyncio.Task] = {}

    async def get_or_create_session(self, sid: str) -> "tuple[KNet, CrawlForAIScraper]":
        """Return the ``(knet, scraper)`` pair for *sid*, creating it on first use.

        A new session starts a fresh ``CrawlForAIScraper`` and wraps it in a
        ``KNet`` before storing the pair.
        """
        if sid not in self.sessions:
            scraper = CrawlForAIScraper()
            await scraper.start()
            knet = KNet(scraper)
            self.sessions[sid] = (knet, scraper)
        return self.sessions[sid]

    async def cleanup_session(self, sid: str):
        """Cancel the task registered for *sid* (if any) and close its scraper.

        Safe to call for an unknown *sid* and safe to call more than once.
        Errors raised by the cancelled task are logged, not propagated; an
        error from ``scraper.close()`` propagates, but the session entry has
        already been removed by then.
        """
        # Remove the entry up front: finished tasks are no longer left behind,
        # and overlapping cleanups for the same sid (abort followed by
        # disconnect) cannot fail on a missing key.
        task = self.tasks.pop(sid, None)
        if task is not None and not task.done():
            task.cancel()
            try:
                # Wait for the task to actually unwind before closing the
                # scraper it may still be using.
                await task
            except asyncio.CancelledError:
                logger.info(f"Research task for session {sid} was cancelled")
            except Exception as e:
                logger.error(f"Error while cancelling task for {sid}: {str(e)}")

        # Drop the session before closing so a failing close() cannot leave a
        # stale entry that a later get_or_create_session() would hand out.
        session = self.sessions.pop(sid, None)
        if session is not None:
            _, scraper = session
            await scraper.close()

    def register_task(self, sid: str, task: asyncio.Task):
        """Remember *task* as the research task for *sid*, replacing any previous one."""
        self.tasks[sid] = task


# Single process-wide SessionManager used by the Socket.IO event handlers below.
session_manager = SessionManager()


@sio.event
async def connect(sid, environ, auth):
    """Handle a new Socket.IO connection: log it and prepare the client's session.

    ``environ`` and ``auth`` are accepted to match the handler signature but
    are not used here.
    """
    connected_msg = f"Client connected: {sid}"
    logger.info(connected_msg)
    # Create the session now so its scraper is started before any research request.
    await session_manager.get_or_create_session(sid)


@sio.event
async def disconnect(sid, reason):
    """Handle a client disconnect: log it and release the client's session.

    ``reason`` is accepted to match the handler signature but is not used here.
    """
    disconnected_msg = f"Client disconnected: {sid}"
    logger.info(disconnected_msg)
    # Cancels any registered research task and closes the session's scraper.
    await session_manager.cleanup_session(sid)


@sio.event
async def health_check(sid, data):
    """Reply to a ``health_check`` event with an ok status, sent only to the caller.

    ``data`` is ignored.
    """
    logger.debug("Health check received")
    ok_payload = {"status": "ok"}
    await sio.emit("health_check", ok_payload, room=sid)


@sio.event
async def start_research(sid, data):
    """Run a research job for the calling client and stream its progress.

    Args:
        sid: Socket.IO session id of the caller; all events are emitted to it.
        data: Request payload, either a dict or a JSON string encoding one.
            Reads ``topic`` (required string), ``max_depth`` and
            ``num_sites_per_query`` (passed through unchanged).

    Emits:
        ``status`` for each progress callback, then exactly one of
        ``research_complete`` (with the results), ``research_aborted`` (empty
        results or the task was cancelled) or ``error`` (with a message).
    """
    try:
        payload = data if isinstance(data, dict) else json.loads(data)
        topic = payload.get("topic")
        if not isinstance(topic, str):
            raise ValueError("'topic' must be a string")
        topic = topic.strip()
        max_depth: int = payload.get("max_depth")
        num_sites_per_query: int = payload.get("num_sites_per_query")

        knet, _ = await session_manager.get_or_create_session(sid)

        logger.info(f"Starting research for client {sid}.\nTopic '{topic}'")

        async def progress_callback(status: dict):
            await sio.emit("status", status, room=sid)

        # Run as a task and register it so abort/disconnect can cancel it.
        task = asyncio.create_task(
            knet.conduct_research(topic, progress_callback, max_depth, num_sites_per_query)
        )
        session_manager.register_task(sid, task)
        research_results = await task

        if not research_results:
            # The emit must be awaited or the event is never sent; stop here
            # so an aborted run is not also reported as complete.
            await sio.emit("research_aborted", room=sid)
            return

        logger.info(f"Research completed for topic: {topic}")
        await sio.emit("research_complete", research_results, room=sid)

    except asyncio.CancelledError:
        # CancelledError is not an Exception subclass, so it needs its own
        # branch; this is the path taken when cleanup_session() cancels the task.
        logger.info(f"Research task for session {sid} was cancelled")
        await sio.emit("research_aborted", room=sid)
    except Exception as e:
        # Use ``sid`` directly: the failure may occur before any other local
        # is bound (e.g. malformed JSON).
        logger.error(f"Research error: {str(e)}")
        await sio.emit("error", {"message": str(e)}, room=sid)


@sio.event
async def abort_research(sid, data=None):
    """Abort the caller's research by tearing down its session.

    Args:
        sid: Socket.IO session id of the caller.
        data: Optional payload sent with the event; ignored. It is accepted so
            a client that attaches a payload does not trigger a ``TypeError``
            when the handler is invoked.
    """
    logger.info(f"Aborting research for client {sid}")
    # Cancels the registered task and closes the scraper; the session is
    # recreated on the next get_or_create_session() call for this sid.
    await session_manager.cleanup_session(sid)


@sio.event
async def test(sid, data):
    """Run ``knet.test`` for the caller and reply with the contents of ``output.log.json``.

    Args:
        sid: Socket.IO session id of the caller; all events are emitted to it.
        data: Request payload, either a dict or a JSON string encoding one.
            Reads ``topic``; surrounding whitespace and newlines are removed.

    Emits:
        ``status`` for each progress callback, then one of
        ``research_complete`` (JSON loaded from ``output.log.json``),
        ``research_aborted`` (task cancelled) or ``error`` (with a message).
    """
    payload = data if isinstance(data, dict) else json.loads(data)
    topic = payload.get("topic").strip().replace("\n", "")
    logger.info(json.dumps(payload, indent=2))

    knet, _ = await session_manager.get_or_create_session(sid)
    # Pause without blocking: time.sleep() here would stall the event loop
    # and therefore every other connected client.
    await asyncio.sleep(1)

    async def progress_callback(status: dict):
        await sio.emit("status", status, room=sid)

    # Create a task and register it for proper cancellation
    task = asyncio.create_task(knet.test(topic, progress_callback))
    session_manager.register_task(sid, task)

    try:
        await task

        with open("output.log.json", "r") as f:
            results = json.load(f)
        await sio.emit("research_complete", results, room=sid)
    except asyncio.CancelledError:
        logger.info(f"Test task for '{topic}' was cancelled")
        await sio.emit("research_aborted", room=sid)
    except Exception as e:
        logger.error(f"Test error: {str(e)}")
        await sio.emit("error", {"message": str(e)}, room=sid)


# Script entry point: serve the ASGI app with uvicorn when run directly.
if __name__ == "__main__":
    logger.info("Starting KnowledgeNet server...")
    # Imported here, so uvicorn is only needed when the module is run as a script.
    import uvicorn

    # Listen on all interfaces, port 5000.
    uvicorn.run(app, host="0.0.0.0", port=5000)