Spaces:
Runtime error
Runtime error
| import os | |
| import sys | |
| import json | |
| import logging | |
| import re | |
| from datetime import datetime | |
| from typing import List, Dict, Optional | |
| from fastapi import FastAPI, HTTPException | |
| from fastapi.responses import JSONResponse, StreamingResponse | |
| from fastapi.middleware.cors import CORSMiddleware | |
| from pydantic import BaseModel | |
| from pymongo import MongoClient | |
| from bson import ObjectId | |
| import asyncio | |
| # Adjust sys path | |
| sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "src"))) | |
| # TxAgent | |
| from txagent.txagent import TxAgent | |
| # MongoDB | |
| from db.mongo import get_mongo_client | |
# Logging setup: the root logger emits timestamped, level-tagged records at
# INFO and above.
_LOG_FORMAT = "%(asctime)s - %(levelname)s - %(message)s"
logging.basicConfig(format=_LOG_FORMAT, level=logging.INFO)
# Named logger used by the helpers and endpoints defined in this module.
logger = logging.getLogger("TxAgentAPI")
# FastAPI application object plus a fully permissive CORS policy.
app = FastAPI(title="TxAgent API", version="2.1.0")
# NOTE(review): wildcard origins/methods/headers are combined with
# allow_credentials=True. The FastAPI CORS documentation says wildcards cannot
# be used when credentials are allowed -- confirm the intended origin list.
_CORS_OPTIONS = {
    "allow_origins": ["*"],
    "allow_credentials": True,
    "allow_methods": ["*"],
    "allow_headers": ["*"],
}
app.add_middleware(CORSMiddleware, **_CORS_OPTIONS)
# Request models
class ChatRequest(BaseModel):
    # Body of a chat request as consumed by chat_stream_endpoint().
    # (Documented with comments rather than a docstring so the generated
    # schema description stays exactly as it was.)

    # New user utterance; appended to the conversation as the last "user" turn.
    message: str
    # Sampling temperature forwarded to the model's generate() call.
    temperature: float = 0.7
    # Upper bound on generated tokens, forwarded to generate().
    max_new_tokens: int = 512
    # Earlier turns, inserted between the system prompt and the new message.
    # NOTE(review): presumably a list of {"role": ..., "content": ...} dicts,
    # matching the turns built in chat_stream_endpoint() -- not validated here.
    history: Optional[List[Dict]] = None
    # NOTE(review): not read by any code visible in this file.
    format: Optional[str] = "clean"
# Process-wide singletons. All start out unset; startup_event() declares them
# global and assigns the model agent, the Mongo client and the two collections.
agent = mongo_client = patients_collection = analysis_collection = None
# Helpers
def clean_text_response(text: str) -> str:
    """Normalise whitespace and strip Markdown emphasis markers from ``text``.

    Steps, in order:
    1. any run of blank (whitespace-only) lines becomes exactly one blank line;
    2. runs of spaces collapse to a single space;
    3. the ``**`` and ``__`` markers are removed;
    4. leading and trailing whitespace is trimmed.
    """
    collapsed_breaks = re.sub(r'\n\s*\n', '\n\n', text)
    cleaned = re.sub(r'[ ]+', ' ', collapsed_breaks)
    for marker in ("**", "__"):
        cleaned = cleaned.replace(marker, "")
    return cleaned.strip()
def extract_section(text: str, heading: str) -> str:
    """Return the cleaned body that follows ``"<heading>:"`` in ``text``.

    The section starts on the line after ``heading:`` and runs until the next
    newline that is immediately followed by a word character, or until the end
    of ``text``.

    Args:
        text: Free-form model output to search.
        heading: Literal heading text, without the trailing colon.

    Returns:
        The section body passed through ``clean_text_response``, or ``""``
        when the heading is absent or extraction fails (failures are logged).
    """
    try:
        # The heading is literal text, not a pattern: escape it so characters
        # such as "(", "+" or "?" neither change what is matched nor make the
        # expression invalid.
        pattern = rf"{re.escape(heading)}:\n(.*?)(?=\n\w|\Z)"
        match = re.search(pattern, text, re.DOTALL)
        if match is None:
            return ""
        return clean_text_response(match.group(1))
    except Exception as e:
        # Best effort: a failed extraction yields an empty section instead of
        # aborting the surrounding analysis.
        logger.error(f"Section extraction failed: {e}")
        return ""
def structure_medical_response(text: str) -> Dict:
    """Split a free-text model answer into its four named sections.

    Each output key is filled by ``extract_section`` using the heading paired
    with it below; a heading that is missing from ``text`` yields ``""``.
    """
    section_headings = (
        ("summary", "Summary"),
        ("risks", "Risks or Red Flags"),
        ("missed_issues", "What the doctor might have missed"),
        ("recommendations", "Suggested Clinical Actions"),
    )
    return {key: extract_section(text, heading) for key, heading in section_headings}
def serialize_patient(patient: dict) -> dict:
    """Return a shallow copy of ``patient`` whose ``_id`` (if any) is a string.

    The input mapping is not modified. Only the top-level ``_id`` value is
    converted; every other value is carried over as-is.
    """
    serialized = patient.copy()
    if "_id" not in serialized:
        return serialized
    serialized["_id"] = str(serialized["_id"])
    return serialized
async def analyze_patient(patient: dict):
    """Run the agent over one patient document and upsert the analysis.

    The patient is serialised to JSON, embedded (truncated to 10,000
    characters) in a fixed prompt and sent to the global ``agent``. The reply
    is split into sections and stored in ``analysis_collection`` keyed by the
    patient's ``fhir_id``.

    Args:
        patient: Patient document; ``fhir_id`` is read from it via ``.get``.

    Any exception is logged and swallowed so that a single failing patient
    does not stop a batch run.
    """
    try:
        patient_id = patient.get("fhir_id")
        # default=str: values other than the top-level _id that JSON cannot
        # represent natively (dates, nested ids, ...) are rendered as text for
        # the prompt instead of raising TypeError and skipping the patient.
        doc = json.dumps(serialize_patient(patient), indent=2, default=str)
        message = (
            "You are a clinical decision support AI.\n\n"
            "Given the patient document below:\n"
            "1. Summarize their medical history.\n"
            "2. Identify risks or red flags.\n"
            "3. Highlight missed diagnoses or treatments.\n"
            "4. Suggest next clinical steps.\n"
            f"\nPatient Document:\n{'-'*40}\n{doc[:10000]}"
        )
        # NOTE(review): this is a plain (non-awaited) call inside a coroutine;
        # if agent.chat blocks while generating, the event loop is stalled for
        # that long -- confirm this is acceptable.
        raw = agent.chat(message=message, history=[], temperature=0.7, max_new_tokens=1024)
        # NOTE(review): the prompt above does not ask for the exact headings
        # structure_medical_response() searches for, so sections may come back
        # empty -- confirm against real model output.
        structured = structure_medical_response(raw)
        analysis_doc = {
            "patient_id": patient_id,
            "timestamp": datetime.utcnow(),
            "summary": structured,
            "raw": raw
        }
        # NOTE(review): awaiting update_one assumes an async collection object;
        # verify what get_mongo_client() actually returns.
        await analysis_collection.update_one(
            {"patient_id": patient_id},
            {"$set": analysis_doc},
            upsert=True
        )
        logger.info(f"✔️ Analysis stored for patient {patient_id}")
    except Exception as e:
        logger.error(f"Error analyzing patient: {e}")
async def analyze_all_patients():
    """Analyse every document in ``patients_collection``, one after another.

    All documents are fetched up front (``to_list(length=None)``); after each
    analysis the coroutine sleeps for 0.1 s before moving on to the next one.
    """
    every_patient = await patients_collection.find({}).to_list(length=None)
    for record in every_patient:
        await analyze_patient(record)
        await asyncio.sleep(0.1)
# Startup logic
# Strong references to fire-and-forget tasks. The event loop only keeps weak
# references, so a task nobody else holds may be garbage-collected mid-run.
_background_tasks = set()


async def startup_event():
    """Initialise the model agent and MongoDB handles, then start analysis.

    Populates the module-level ``agent``, ``mongo_client``,
    ``patients_collection`` and ``analysis_collection`` and schedules
    ``analyze_all_patients()`` as a background task.

    NOTE(review): no registration decorator is visible on this function in the
    reviewed text -- confirm it is actually wired to application startup.
    """
    global agent, mongo_client, patients_collection, analysis_collection
    # Init agent
    agent = TxAgent(
        model_name="mims-harvard/TxAgent-T1-Llama-3.1-8B",
        rag_model_name="mims-harvard/ToolRAG-T1-GTE-Qwen2-1.5B",
        enable_finish=True,
        enable_rag=False,
        force_finish=True,
        enable_checker=True,
        step_rag_num=4,
        seed=42
    )
    # System prompt; chat_stream_endpoint() sends it as the first turn.
    agent.chat_prompt = (
        "You are a clinical assistant AI. Analyze the patient's data and provide clear clinical recommendations."
    )
    agent.init_model()
    logger.info("✅ TxAgent initialized")
    # MongoDB
    mongo_client = get_mongo_client()
    db = mongo_client.get_default_database()
    patients_collection = db.get_collection("patients")
    analysis_collection = db.get_collection("patient_analysis_results")
    logger.info("📡 Connected to MongoDB")
    # Keep the task referenced until it finishes, then drop the reference.
    analysis_task = asyncio.create_task(analyze_all_patients())
    _background_tasks.add(analysis_task)
    analysis_task.add_done_callback(_background_tasks.discard)
# Endpoints
# NOTE(review): no route decorator is visible on the handlers below in the
# reviewed text -- confirm how they are registered on ``app``.
async def status():
    """Report service liveness, the API version and the current time.

    The timestamp is the naive ``datetime.utcnow()`` value in ISO-8601 form.
    """
    now_iso = datetime.utcnow().isoformat()
    return dict(status="running", version="2.1.0", timestamp=now_iso)
async def chat_stream_endpoint(request: ChatRequest):
    """Answer a chat request as a plain-text stream of words.

    The conversation is the agent's system prompt, any supplied history, then
    the new user message. The reply is generated in one call, decoded, and
    sent back one whitespace-separated word at a time.

    NOTE(review): generate() completes before the first chunk is yielded, so
    the client sees nothing until the whole reply exists; splitting on
    whitespace also discards the reply's original line breaks.

    Args:
        request: Message, optional history and sampling settings.

    Returns:
        A ``StreamingResponse`` with media type ``text/plain``. If anything
        fails, the error is logged and sent to the client as the final chunk.
    """
    async def token_stream():
        try:
            turns = [{"role": "system", "content": agent.chat_prompt}]
            if request.history:
                turns += request.history
            turns.append({"role": "user", "content": request.message})
            prompt_ids = agent.tokenizer.apply_chat_template(
                turns, add_generation_prompt=True, return_tensors="pt"
            ).to(agent.device)
            generation = agent.model.generate(
                prompt_ids,
                do_sample=True,
                temperature=request.temperature,
                max_new_tokens=request.max_new_tokens,
                pad_token_id=agent.tokenizer.eos_token_id,
                return_dict_in_generate=True
            )
            # Skip the prompt positions so only newly generated ids are decoded.
            prompt_length = prompt_ids.shape[1]
            completion_ids = generation["sequences"][0][prompt_length:]
            text = agent.tokenizer.decode(completion_ids, skip_special_tokens=True)
            for word in text.split():
                yield word + " "
                # Short pause between chunks to pace the output.
                await asyncio.sleep(0.05)
        except Exception as e:
            logger.error(f"Streaming error: {e}")
            yield f"⚠️ Error: {e}"
    return StreamingResponse(token_stream(), media_type="text/plain")