# NOTE(review): "Spaces: Paused" banner text below was hosting-UI residue
# captured into the file, not Python source; kept here as a comment so the
# module parses.
import json
import logging
import os
import re
import shutil
import time
from datetime import datetime
from pathlib import Path
from typing import Any, Dict, List

from fastapi import HTTPException
from fastapi.responses import FileResponse, JSONResponse

from gpt_researcher import GPTResearcher
from gpt_researcher.document.document import DocumentLoader

from backend.utils import write_md_to_pdf, write_md_to_word, write_text_to_md
# Module-wide logging setup.
# NOTE(review): basicConfig(DEBUG) at import time is very verbose and affects
# the root logger for the whole process — confirm this is intended.
logging.basicConfig(level=logging.DEBUG)
logger = logging.getLogger(__name__)
class CustomLogsHandler:
    """Capture streaming research logs, mirroring each entry to a websocket
    (when present) and to a per-task JSON file under /tmp/outputs."""

    def __init__(self, websocket, task: str):
        self.logs = []
        self.websocket = websocket
        safe_name = sanitize_filename(f"task_{int(time.time())}_{task}")
        self.log_file = os.path.join("/tmp/outputs", f"{safe_name}.json")
        self.timestamp = datetime.now().isoformat()
        # Seed the log file with an empty skeleton so send_json can always
        # read it back before updating.
        os.makedirs("/tmp/outputs", exist_ok=True)
        skeleton = {
            "timestamp": self.timestamp,
            "events": [],
            "content": {
                "query": "",
                "sources": [],
                "context": [],
                "report": "",
                "costs": 0.0,
            },
        }
        with open(self.log_file, 'w') as f:
            json.dump(skeleton, f, indent=2)

    async def send_json(self, data: Dict[str, Any]) -> None:
        """Forward *data* to the websocket for real-time display, then
        persist it into the JSON log file."""
        if self.websocket:
            await self.websocket.send_json(data)

        # Read-modify-write the log file on every call.
        with open(self.log_file, 'r') as f:
            log_data = json.load(f)

        if data.get('type') == 'logs':
            # Streaming log events append to the chronological event list.
            log_data['events'].append({
                "timestamp": datetime.now().isoformat(),
                "type": "event",
                "data": data,
            })
        else:
            # Everything else (query/sources/context/report/...) merges into
            # the content section.
            log_data['content'].update(data)

        with open(self.log_file, 'w') as f:
            json.dump(log_data, f, indent=2)
        logger.debug(f"Log entry written to: {self.log_file}")
class Researcher:
    """Drives a single GPTResearcher run and collects the generated artifacts."""

    def __init__(self, query: str, report_type: str = "research_report"):
        self.query = query
        self.report_type = report_type
        # Unique ID for this research task: timestamp plus query hash.
        self.research_id = f"{datetime.now().strftime('%Y%m%d_%H%M%S')}_{hash(query)}"
        # BUG FIX: CustomLogsHandler's signature is (websocket, task); the
        # websocket argument was previously omitted, so research_id was bound
        # to `websocket` and `task` was missing entirely.  There is no
        # websocket in this context, so pass None explicitly.
        self.logs_handler = CustomLogsHandler(None, self.research_id)
        self.researcher = GPTResearcher(
            query=query,
            report_type=report_type,
            websocket=self.logs_handler
        )

    async def research(self) -> dict:
        """Conduct research and return paths to the generated report files.

        Returns:
            {"output": {"pdf": ..., "docx": ..., "md": ..., "json": ...}}
        """
        await self.researcher.conduct_research()
        report = await self.researcher.write_report()

        # Render the report into pdf/docx/md files.
        sanitized_filename = sanitize_filename(f"task_{int(time.time())}_{self.query}")
        file_paths = await generate_report_files(report, sanitized_filename)

        # The JSON log file was created by CustomLogsHandler in __init__.
        json_relative_path = os.path.relpath(self.logs_handler.log_file)
        return {
            "output": {
                **file_paths,  # pdf, docx and md paths
                "json": json_relative_path
            }
        }
def sanitize_filename(filename: str) -> str:
    """Truncate and clean a 'task_<timestamp>_<task>' string so the resulting
    filename stays within filesystem limits and contains only safe characters."""
    prefix, timestamp, *task_parts = filename.split('_')
    task = '_'.join(task_parts)

    # Budget for the task portion:
    # 255 - len("/tmp/outputs/") - len("task_") - len(timestamp) - len("_.json") - safety margin
    max_task_length = 255 - 8 - 5 - 10 - 6 - 10  # ~216 chars for the task

    if len(task) > max_task_length:
        task = task[:max_task_length]

    # Reassemble, drop anything that is not a word char / whitespace / hyphen,
    # and trim surrounding whitespace.
    rebuilt = '_'.join((prefix, timestamp, task))
    return re.sub(r"[^\w\s-]", "", rebuilt).strip()
async def handle_start_command(websocket, data: str, manager):
    """Parse a 'start ...' websocket command, run the streaming research task,
    and send the generated file paths back to the client."""
    payload = json.loads(data[6:])  # strip the "start " prefix
    task, report_type, source_urls, document_urls, tone, headers, report_source = \
        extract_command_data(payload)

    if not task or not report_type:
        print("Error: Missing task or report_type")
        return

    # Mirrors events to the websocket and to a per-task JSON log file.
    log_sink = CustomLogsHandler(websocket, task)
    await log_sink.send_json({
        "query": task,
        "sources": [],
        "context": [],
        "report": ""
    })

    base_name = sanitize_filename(f"task_{int(time.time())}_{task}")

    report = await manager.start_streaming(
        task,
        report_type,
        report_source,
        source_urls,
        document_urls,
        tone,
        websocket,
        headers
    )
    report = str(report)

    outputs = await generate_report_files(report, base_name)
    # Include the JSON log path alongside the pdf/docx/md paths.
    outputs["json"] = os.path.relpath(log_sink.log_file)
    await send_file_paths(websocket, outputs)
async def handle_human_feedback(data: str):
    """Parse and log a 'human_feedback...' websocket message."""
    feedback = json.loads(data[len("human_feedback"):])
    print(f"Received human feedback: {feedback}")
    # TODO: Add logic to forward the feedback to the appropriate agent or update the research state
async def handle_chat(websocket, data: str, manager):
    """Relay a 'chat...' websocket command to the connection manager."""
    message = json.loads(data[4:]).get("message")  # strip the "chat" prefix
    print(f"Received chat message: {message}")
    await manager.chat(message, websocket)
async def generate_report_files(report: str, filename: str) -> Dict[str, str]:
    """Render *report* into pdf, docx and md files; return paths keyed by format."""
    # Dict literal preserves the original pdf -> docx -> md generation order.
    return {
        "pdf": await write_md_to_pdf(report, filename),
        "docx": await write_md_to_word(report, filename),
        "md": await write_text_to_md(report, filename),
    }
async def send_file_paths(websocket, file_paths: Dict[str, str]):
    """Push the generated artifact paths to the client over the websocket."""
    payload = {"type": "path", "output": file_paths}
    await websocket.send_json(payload)
def get_config_dict(
    langchain_api_key: str, openai_api_key: str, tavily_api_key: str,
    google_api_key: str, google_cx_key: str, bing_api_key: str,
    searchapi_api_key: str, serpapi_api_key: str, serper_api_key: str, searx_url: str
) -> Dict[str, str]:
    """Assemble the runtime configuration, preferring explicit arguments over
    environment variables (empty string when neither is set)."""
    # Keys whose values can be overridden by the caller.
    overridable = {
        "LANGCHAIN_API_KEY": langchain_api_key,
        "OPENAI_API_KEY": openai_api_key,
        "TAVILY_API_KEY": tavily_api_key,
        "GOOGLE_API_KEY": google_api_key,
        "GOOGLE_CX_KEY": google_cx_key,
        "BING_API_KEY": bing_api_key,
        "SEARCHAPI_API_KEY": searchapi_api_key,
        "SERPAPI_API_KEY": serpapi_api_key,
        "SERPER_API_KEY": serper_api_key,
        "SEARX_URL": searx_url,
    }
    config = {key: value or os.getenv(key, "") for key, value in overridable.items()}

    # Environment-only keys with their own defaults.
    config["LANGCHAIN_TRACING_V2"] = os.getenv("LANGCHAIN_TRACING_V2", "true")
    config["DOC_PATH"] = os.getenv("DOC_PATH", "/tmp/my-docs")
    config["RETRIEVER"] = os.getenv("RETRIEVER", "")
    config["EMBEDDING_MODEL"] = os.getenv("OPENAI_EMBEDDING_MODEL", "")
    return config
def update_environment_variables(config: Dict[str, str]):
    """Export every entry of *config* into the process environment."""
    # os.environ.update performs the same per-key assignment as an explicit loop.
    os.environ.update(config)
async def handle_file_upload(file, DOC_PATH: str) -> Dict[str, str]:
    """Save an uploaded file under DOC_PATH and reload the document index."""
    destination = os.path.join(DOC_PATH, os.path.basename(file.filename))
    with open(destination, "wb") as out:
        shutil.copyfileobj(file.file, out)
    print(f"File uploaded to {destination}")

    # Reload so the newly uploaded document becomes available to research.
    document_loader = DocumentLoader(DOC_PATH)
    await document_loader.load()
    return {"filename": file.filename, "path": destination}
async def handle_file_deletion(filename: str, DOC_PATH: str) -> JSONResponse:
    """Delete *filename* from DOC_PATH; return 404 JSON when it does not exist."""
    # basename() guards against path traversal in the client-supplied name.
    file_path = os.path.join(DOC_PATH, os.path.basename(filename))
    if not os.path.exists(file_path):
        print(f"File not found: {file_path}")
        return JSONResponse(status_code=404, content={"message": "File not found"})
    os.remove(file_path)
    print(f"File deleted: {file_path}")
    return JSONResponse(content={"message": "File deleted successfully"})
async def execute_multi_agents(manager) -> Any:
    """Run a multi-agent research task over the first active websocket connection.

    Returns a {"report": ...} dict on success, or a 400 JSONResponse when no
    websocket connection is active.
    """
    # NOTE(review): `run_research_task` and `stream_output` are not defined or
    # imported anywhere in this file — presumably they come from the
    # multi_agents package; confirm the import exists, otherwise this raises
    # NameError at call time.
    websocket = manager.active_connections[0] if manager.active_connections else None
    if websocket:
        # Hard-coded demo query; looks like a placeholder — TODO confirm.
        report = await run_research_task("Is AI in a hype cycle?", websocket, stream_output)
        return {"report": report}
    else:
        return JSONResponse(status_code=400, content={"message": "No active WebSocket connection"})
async def handle_websocket_communication(websocket, manager):
    """Receive loop: dispatch each incoming websocket message by its text prefix.

    Runs until the websocket raises (e.g. on disconnect).
    """
    while True:
        message = await websocket.receive_text()
        # Prefix-based command dispatch; order matters only for readability,
        # since the prefixes do not overlap.
        if message.startswith("start"):
            await handle_start_command(websocket, message, manager)
        elif message.startswith("human_feedback"):
            await handle_human_feedback(message)
        elif message.startswith("chat"):
            await handle_chat(websocket, message, manager)
        else:
            print("Error: Unknown command or not enough parameters provided.")
def extract_command_data(json_data: Dict) -> tuple:
    """Pull the start-command fields out of *json_data* in a fixed order.

    Returns (task, report_type, source_urls, document_urls, tone, headers,
    report_source); missing fields are None, except headers which defaults to {}.
    """
    plain_keys = ("task", "report_type", "source_urls", "document_urls", "tone")
    values = [json_data.get(key) for key in plain_keys]
    values.append(json_data.get("headers", {}))
    values.append(json_data.get("report_source"))
    return tuple(values)