import sys
import os
import json
import shutil
import re
import gc
import time
from datetime import datetime
from typing import List, Tuple, Dict, Union, Optional
from fastapi import FastAPI, UploadFile, File, HTTPException
from fastapi.responses import FileResponse, JSONResponse
from fastapi.middleware.cors import CORSMiddleware
import pandas as pd
import pdfplumber
import torch
import matplotlib.pyplot as plt
from fpdf import FPDF
import unicodedata
import uvicorn
# === Configuration ===
persistent_dir = "/data/hf_cache"
model_cache_dir = os.path.join(persistent_dir, "txagent_models")
tool_cache_dir = os.path.join(persistent_dir, "tool_cache")
file_cache_dir = os.path.join(persistent_dir, "cache")
report_dir = os.path.join(persistent_dir, "reports")

# Create directories if they don't exist
for d in [model_cache_dir, tool_cache_dir, file_cache_dir, report_dir]:
    os.makedirs(d, exist_ok=True)

# Set environment variables
os.environ["HF_HOME"] = model_cache_dir
os.environ["TRANSFORMERS_CACHE"] = model_cache_dir
os.environ["MPLCONFIGDIR"] = "/tmp/matplotlib"  # Fix for matplotlib permission issues

# Set up Python path
current_dir = os.path.dirname(os.path.abspath(__file__))
src_path = os.path.abspath(os.path.join(current_dir, "src"))
sys.path.insert(0, src_path)

# Import TxAgent after setting up paths
from txagent.txagent import TxAgent

# Constants
MAX_MODEL_TOKENS = 131072
MAX_NEW_TOKENS = 4096
MAX_CHUNK_TOKENS = 8192
BATCH_SIZE = 1
PROMPT_OVERHEAD = 300
SAFE_SLEEP = 0.5
# Initialize FastAPI app
app = FastAPI(
    title="Clinical Patient Support System API",
    description="API for analyzing and summarizing unstructured medical files",
    version="1.0.0"
)

# CORS configuration
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
# Initialize agent at startup
agent = None

@app.on_event("startup")  # register the handler; without this decorator it never runs
async def startup_event():
    global agent
    try:
        agent = init_agent()
    except Exception as e:
        raise RuntimeError(f"Failed to initialize agent: {str(e)}") from e
def init_agent() -> TxAgent:
    """Initialize and return the TxAgent instance."""
    tool_path = os.path.join(tool_cache_dir, "new_tool.json")
    if not os.path.exists(tool_path):
        shutil.copy(os.path.abspath("data/new_tool.json"), tool_path)
    agent = TxAgent(
        model_name="mims-harvard/TxAgent-T1-Llama-3.1-8B",
        rag_model_name="mims-harvard/ToolRAG-T1-GTE-Qwen2-1.5B",
        tool_files_dict={"new_tool": tool_path},
        force_finish=True,
        enable_checker=True,
        step_rag_num=4,
        seed=100
    )
    agent.init_model()
    return agent
# Utility functions
def estimate_tokens(text: str) -> int:
    """Estimate the number of tokens in the given text (~4 characters per token)."""
    return len(text) // 4 + 1
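
# NOTE: split_text is called by analyze_document below but was missing from
# this file. This is a minimal sketch: it packs lines into chunks whose
# estimated size (via the estimate_tokens heuristic above) stays under a
# budget that leaves room for the prompt itself.
def split_text(text: str, max_tokens: int = MAX_CHUNK_TOKENS - PROMPT_OVERHEAD) -> List[str]:
    """Split text into chunks whose estimated token count stays under max_tokens."""
    chunks: List[str] = []
    current: List[str] = []
    current_tokens = 0
    for line in text.split("\n"):
        line_tokens = estimate_tokens(line)
        if current and current_tokens + line_tokens > max_tokens:
            chunks.append("\n".join(current))
            current, current_tokens = [], 0
        current.append(line)
        current_tokens += line_tokens
    if current:
        chunks.append("\n".join(current))
    return chunks
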
def clean_response(text: str) -> str:
    """Clean and format the response text."""
    if not text:
        return ""
    text = re.sub(r"\[.*?\]|\bNone\b", "", text, flags=re.DOTALL)
    text = re.sub(r"\n{3,}", "\n\n", text)
    return text.strip()
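
# NOTE: batch_chunks was also missing from this file; a minimal sketch that
# groups chunks into fixed-size batches for sequential analysis.
def batch_chunks(chunks: List[str], batch_size: int = BATCH_SIZE) -> List[List[str]]:
    """Group chunks into batches of at most batch_size."""
    return [chunks[i:i + batch_size] for i in range(0, len(chunks), batch_size)]
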
def extract_text_from_excel(path: str) -> str:
    """Extract text from an Excel file."""
    try:
        all_text = []
        xls = pd.ExcelFile(path)
        for sheet_name in xls.sheet_names:
            try:
                # fillna before astype(str); otherwise NaN becomes the string "nan"
                df = xls.parse(sheet_name).fillna("").astype(str)
            except Exception:
                continue
            for _, row in df.iterrows():
                non_empty = [cell.strip() for cell in row if cell.strip()]
                if len(non_empty) >= 2:
                    text_line = " | ".join(non_empty)
                    if len(text_line) > 15:
                        all_text.append(f"[{sheet_name}] {text_line}")
        return "\n".join(all_text)
    except Exception as e:
        raise RuntimeError(f"Failed to extract text from Excel: {str(e)}") from e
def extract_text(file_path: str) -> str:
    """Extract text from supported file types (.xlsx, .csv, .pdf)."""
    try:
        ext = os.path.splitext(file_path)[1].lower()
        if ext == ".xlsx":
            return extract_text_from_excel(file_path)
        elif ext == ".csv":
            # fillna before astype(str); otherwise NaN becomes the string "nan"
            df = pd.read_csv(file_path).fillna("").astype(str)
            return "\n".join(
                " | ".join(cell.strip() for cell in row if cell.strip())
                for _, row in df.iterrows()
                if len([cell for cell in row if cell.strip()]) >= 2
            )
        elif ext == ".pdf":
            with pdfplumber.open(file_path) as pdf:
                return "\n".join(page.extract_text() or "" for page in pdf.pages)
        else:
            return ""
    except Exception as e:
        raise RuntimeError(f"Failed to extract text from file: {str(e)}") from e
# API endpoints
@app.post("/analyze")  # route path assumed; the original file omitted the decorator
async def analyze_document(file: UploadFile = File(...)):
    """Analyze a medical document and return results."""
    start_time = time.time()
    # Save the uploaded file temporarily (basename only, to avoid path traversal)
    temp_path = os.path.join(file_cache_dir, os.path.basename(file.filename or "upload"))
    try:
        with open(temp_path, "wb") as f:
            f.write(await file.read())

        extracted = extract_text(temp_path)
        if not extracted:
            raise HTTPException(status_code=400, detail="Could not extract text from the file")

        chunks = split_text(extracted)
        batches = batch_chunks(chunks, batch_size=BATCH_SIZE)
        batch_results = analyze_batches(agent, batches)
        valid_results = [res for res in batch_results if not res.startswith("❌")]
        if not valid_results:
            raise HTTPException(status_code=400, detail="No valid analysis results were generated")

        final_summary = generate_final_summary(agent, "\n\n".join(valid_results))

        # Generate report files
        report_filename = f"report_{datetime.now().strftime('%Y%m%d_%H%M%S')}"
        report_path = os.path.join(report_dir, f"{report_filename}.md")
        with open(report_path, 'w', encoding='utf-8') as f:
            f.write(f"# Final Medical Report\n\n{final_summary}")

        pdf_path = generate_pdf_report_with_charts(final_summary, report_path, detailed_batches=batch_results)

        return JSONResponse({
            "status": "success",
            "summary": final_summary,
            "report_path": f"/reports/{os.path.basename(pdf_path)}",
            "processing_time": f"{time.time() - start_time:.2f} seconds",
            "detailed_outputs": batch_results
        })
    except HTTPException:
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e
    finally:
        # Clean up temp file even when analysis fails
        if os.path.exists(temp_path):
            os.remove(temp_path)
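
# NOTE: generate_pdf_report_with_charts is called above but was missing from
# this file. This is a minimal, text-only sketch: the original chart logic is
# not recoverable, so matplotlib charts are omitted and only the summary plus
# the per-batch details are written. FPDF's core fonts are latin-1 only,
# hence the unicodedata normalization.
def generate_pdf_report_with_charts(
    summary: str,
    report_path: str,
    detailed_batches: Optional[List[str]] = None,
) -> str:
    """Render the final summary (and batch details) to a PDF next to report_path."""
    pdf_path = os.path.splitext(report_path)[0] + ".pdf"

    def to_latin1(text: str) -> str:
        # Drop characters the core FPDF fonts cannot render.
        return unicodedata.normalize("NFKD", text).encode("latin-1", "ignore").decode("latin-1")

    pdf = FPDF()
    pdf.add_page()
    pdf.set_font("Arial", "B", 14)
    pdf.multi_cell(0, 10, "Final Medical Report")
    pdf.set_font("Arial", size=11)
    pdf.multi_cell(0, 8, to_latin1(summary))
    for i, batch in enumerate(detailed_batches or [], start=1):
        pdf.add_page()
        pdf.multi_cell(0, 8, to_latin1(f"Batch {i}\n\n{batch}"))
    pdf.output(pdf_path)
    return pdf_path
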
@app.get("/reports/{filename}")  # matches the report_path returned by analyze_document
async def download_report(filename: str):
    """Download a generated report."""
    file_path = os.path.join(report_dir, os.path.basename(filename))
    if not os.path.exists(file_path):
        raise HTTPException(status_code=404, detail="Report not found")
    return FileResponse(file_path, media_type='application/pdf', filename=filename)
@app.get("/status")  # route path assumed; the original file omitted the decorator
async def service_status():
    """Check service status."""
    return {
        "status": "running",
        "version": "1.0.0",
        "model": "mims-harvard/TxAgent-T1-Llama-3.1-8B",
        "max_tokens": MAX_MODEL_TOKENS,
        "supported_file_types": [".pdf", ".xlsx", ".csv"]
    }
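
# Example requests (assuming the route paths registered above):
#   curl -X POST -F "file=@records.xlsx" http://localhost:7860/analyze
#   curl http://localhost:7860/status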
if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=7860)