"""Pub/Sub multi-agent orchestration service.

FastAPI app that wires user-defined "agents" (LLM, NER, or SQL) together
through an in-memory topic bus and streams execution progress to the
frontend as Server-Sent Events.
"""

from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse, FileResponse
from fastapi.staticfiles import StaticFiles
from pydantic import BaseModel
from langchain_community.llms import Ollama
from langchain.prompts import PromptTemplate
from typing import List, Dict, Optional, AsyncGenerator
import json
import asyncio
from pathlib import Path
import os
import re
from database import init_db
from db_extensions import register_extensions

# Import transformers for NER (optional dependency — degrade gracefully)
try:
    from transformers import pipeline
    TRANSFORMERS_AVAILABLE = True
except ImportError:
    TRANSFORMERS_AVAILABLE = False
    print("Warning: transformers not available, NER models will not work")

app = FastAPI(title="Pub/Sub Multi-Agent System")

# Populated by the startup hook; `con` is the shared DB connection.
db_ready = False
con = None

# Enable CORS
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Mount static files
static_dir = Path(__file__).parent / "static"
static_dir.mkdir(exist_ok=True)
app.mount("/static", StaticFiles(directory=static_dir), name="static")


# ---------------------------------------------------------------------------
# Request/response models
# ---------------------------------------------------------------------------

class DataSource(BaseModel):
    """A named blob of text; may subscribe to a topic to be overwritten."""
    label: str
    content: str
    subscribe_topic: Optional[str] = None


class Agent(BaseModel):
    """A pipeline step: consumes one topic, optionally publishes another."""
    title: str
    prompt: str
    model: str  # Ollama model name, an NER model id, or the literal "SQL"
    subscribe_topic: str
    publish_topic: Optional[str] = None
    show_result: bool = False


class ExecutionRequest(BaseModel):
    """Payload for /execute: the full agent graph plus its inputs."""
    data_sources: List[DataSource]
    user_question: str = ""
    agents: List[Agent]


class Query(BaseModel):
    sql: str


# ---------------------------------------------------------------------------
# Pub/Sub bus
# ---------------------------------------------------------------------------

class MessageBus:
    """In-memory topic bus. Topics are matched case-insensitively; each
    topic holds at most one (the latest) message."""

    def __init__(self):
        self.subscribers: Dict[str, List[Agent]] = {}
        self.datasource_subscribers: Dict[str, List[DataSource]] = {}
        self.messages: Dict[str, str] = {}

    def reset(self):
        """Reset the bus for a new execution"""
        self.subscribers = {}
        self.datasource_subscribers = {}
        self.messages = {}

    def _normalize_topic(self, topic: str) -> str:
        """Normalize topic to lowercase for case-insensitive matching"""
        return topic.lower().strip()

    def subscribe(self, topic: str, agent: Agent):
        """Subscribe an agent to a topic (case insensitive)"""
        normalized = self._normalize_topic(topic)
        if normalized not in self.subscribers:
            self.subscribers[normalized] = []
        self.subscribers[normalized].append(agent)

    def subscribe_datasource(self, topic: str, datasource: DataSource):
        """Subscribe a data source to a topic (case insensitive)"""
        normalized = self._normalize_topic(topic)
        if normalized not in self.datasource_subscribers:
            self.datasource_subscribers[normalized] = []
        self.datasource_subscribers[normalized].append(datasource)

    def publish(self, topic: str, content: str):
        """Publish a message to a topic (case insensitive)"""
        normalized = self._normalize_topic(topic)
        self.messages[normalized] = content

    def get_message(self, topic: str) -> Optional[str]:
        """Get message from a topic (case insensitive)"""
        normalized = self._normalize_topic(topic)
        return self.messages.get(normalized)

    def get_subscribers(self, topic: str) -> List[Agent]:
        """Get all subscribers for a topic (case insensitive)"""
        normalized = self._normalize_topic(topic)
        return self.subscribers.get(normalized, [])

    def get_datasource_subscribers(self, topic: str) -> List[DataSource]:
        """Get all data source subscribers for a topic (case insensitive)"""
        normalized = self._normalize_topic(topic)
        return self.datasource_subscribers.get(normalized, [])


# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------

def create_event(event_type: str, **kwargs) -> str:
    """Serialize a progress event as one Server-Sent Events frame."""
    data = {"type": event_type, **kwargs}
    return f"data: {json.dumps(data)}\n\n"


def get_llm(model_name: str):
    """Build an Ollama LLM client for the given model name."""
    return Ollama(model=model_name, temperature=0.1)


# NER pipeline cache — models are expensive to load, so keep one per name.
_ner_pipelines = {}


def get_ner_pipeline(model_name: str):
    """Get or create NER pipeline for the specified model.

    Raises:
        RuntimeError: if the transformers package is not installed.
    """
    if not TRANSFORMERS_AVAILABLE:
        raise RuntimeError("transformers package not available")
    if model_name not in _ner_pipelines:
        print(f"Loading NER model: {model_name}")
        _ner_pipelines[model_name] = pipeline(
            "ner",
            model=model_name,
            aggregation_strategy="simple"
        )
    return _ner_pipelines[model_name]


def is_ner_model(model_name: str) -> bool:
    """Check if the model is an NER model"""
    ner_models = [
        "samrawal/bert-base-uncased_clinical-ner",
        "OpenMed/OpenMed-NER-AnatomyDetect-BioPatient-108M"
    ]
    return model_name in ner_models


def is_sql_agent(model_name: str) -> bool:
    """Check if the model is SQL agent"""
    return model_name.upper() == "SQL"


def format_ner_result(text: str, entities: List[Dict]) -> str:
    """Format NER entities for human-readable display.

    Each entity span becomes "[span:TYPE]". Entities are applied from the
    end of the text backwards so earlier replacements don't shift the
    start/end offsets of later ones.
    """
    if not entities:
        return text
    # Sort entities by start position in reverse to avoid index issues
    sorted_entities = sorted(entities, key=lambda x: x['start'], reverse=True)
    result = text
    for entity in sorted_entities:
        start = entity['start']
        end = entity['end']
        entity_type = entity['entity_group']
        original_text = text[start:end]
        # Replace entity with labeled version
        labeled = f"[{original_text}:{entity_type}]"
        result = result[:start] + labeled + result[end:]
    return result


def execute_sql_query(sql: str) -> tuple[str, Optional[Dict]]:
    """Execute SQL query and return (json_string, result_dict).

    On failure the json_string encodes an {"error": ...} object and the
    dict is None, so callers can always unpack two values.
    """
    if not db_ready:
        # BUG FIX: the original returned a bare dict here, violating the
        # declared tuple contract — callers that unpack two values would
        # crash with ValueError instead of seeing the error payload.
        return json.dumps({"error": "Database not ready"}), None
    try:
        result = con.execute(sql).fetchall()
        columns = [desc[0] for desc in con.description]
        formatted_result = {
            "columns": columns,
            "rows": result,
            "row_count": len(result)
        }
        json_output = json.dumps(formatted_result, indent=2)
        return json_output, formatted_result
    except Exception as e:
        error_msg = f"SQL execution failed: {str(e)}"
        return json.dumps({"error": error_msg}), None


def process_ner(text: str, model_name: str) -> tuple[str, List[Dict]]:
    """Process text with NER pipeline and return JSON + formatted entities"""
    try:
        ner_pipeline = get_ner_pipeline(model_name)
        # Run NER
        entities = ner_pipeline(text)
        # Convert to our format with proper type conversion
        formatted_entities = []
        for entity in entities:
            formatted_entities.append({
                "text": str(entity['word']),
                "entity_type": str(entity['entity_group']),
                "start": int(entity['start']),
                "end": int(entity['end']),
                "score": float(entity.get('score', 0.0))  # Convert numpy float32 to Python float
            })
        # Create JSON output with proper serialization
        json_output = json.dumps(formatted_entities, indent=2)
        return json_output, formatted_entities
    except Exception as e:
        error_msg = f"NER processing failed: {str(e)}"
        return json.dumps({"error": error_msg}), []


async def execute_agent(agent: Agent, input_content: str, data_sources: List[DataSource],
                        user_question: str) -> tuple[str, Optional[List[Dict]], Optional[str]]:
    """Execute a single agent with the given input.

    Returns (result, entities, analyzed_text):
    - SQL agent: (query result JSON, None, the SQL that ran)
    - NER agent: (entity JSON, entity list, the text analyzed)
    - LLM agent: (model output, None, None)
    """

    # Case-insensitive replacement helper
    def replace_case_insensitive(text: str, placeholder: str, value: str) -> str:
        """Replace placeholder in text, case insensitive"""
        pattern = re.compile(re.escape(placeholder), re.IGNORECASE)
        return pattern.sub(value, text)

    # Start with the agent's prompt template
    prompt_text = agent.prompt if agent.prompt else ""

    # Replace standard placeholders (case insensitive)
    prompt_text = replace_case_insensitive(prompt_text, "{input}", input_content)
    prompt_text = replace_case_insensitive(prompt_text, "{question}", user_question)

    # Replace data source placeholders (case insensitive)
    for ds in data_sources:
        placeholder = "{" + ds.label + "}"
        prompt_text = replace_case_insensitive(prompt_text, placeholder, ds.content)

    # Check agent type
    if is_sql_agent(agent.model):
        # SQL agent: rendered prompt IS the SQL query
        sql_query = prompt_text.strip()
        # If prompt is empty, use input content as SQL
        if not sql_query:
            sql_query = input_content
        # Execute SQL query
        json_result, query_result = execute_sql_query(sql_query)
        # Return JSON result, query result dict, and the SQL that was executed
        return json_result, None, sql_query
    elif is_ner_model(agent.model):
        # For NER models, the rendered prompt IS the text to analyze
        text_to_analyze = prompt_text
        # If prompt is empty, use input content directly
        if not text_to_analyze.strip():
            text_to_analyze = input_content
        # Process with NER pipeline
        json_result, entities = process_ner(text_to_analyze, agent.model)
        # Return JSON result, entities, and the text that was analyzed
        return json_result, entities, text_to_analyze
    else:
        # Regular LLM processing
        llm = get_llm(agent.model)
        # Invoke LLM with the rendered prompt
        result = llm.invoke(prompt_text)
        return (result if isinstance(result, str) else str(result)), None, None


async def execute_pipeline(request: ExecutionRequest) -> AsyncGenerator[str, None]:
    """Run the agent graph and yield SSE frames describing progress.

    Starts by publishing the user question on the "START" topic, then
    repeatedly drains unprocessed topics (bounded by max_iterations) so
    agents can chain by publishing to each other's subscribe topics.
    """
    try:
        bus = MessageBus()
        yield create_event("bus_init")

        # Reset and configure subscriptions
        bus.reset()

        # Subscribe all agents to their topics
        for agent in request.agents:
            if agent.subscribe_topic:
                bus.subscribe(agent.subscribe_topic, agent)
                yield create_event("agent_subscribed", agent=agent.title, topic=agent.subscribe_topic)

        # Subscribe data sources to their topics
        for datasource in request.data_sources:
            if datasource.subscribe_topic:
                bus.subscribe_datasource(datasource.subscribe_topic, datasource)
                yield create_event("datasource_subscribed", datasource=datasource.label,
                                   topic=datasource.subscribe_topic)

        # Publish START message
        start_message = request.user_question if request.user_question else "System initialized"
        bus.publish("START", start_message)
        yield create_event("message_published", topic="START", content=start_message)

        # Process messages in the bus
        processed_topics = set()
        max_iterations = 20  # Prevent infinite loops
        iteration = 0

        while iteration < max_iterations:
            iteration += 1
            # Find topics that have messages but haven't been processed
            topics_to_process = [topic for topic in bus.messages.keys()
                                 if topic not in processed_topics]
            if not topics_to_process:
                break

            for topic in topics_to_process:
                subscribers = bus.get_subscribers(topic)
                datasource_subscribers = bus.get_datasource_subscribers(topic)

                if not subscribers and not datasource_subscribers:
                    yield create_event("no_subscribers", topic=topic)
                    processed_topics.add(topic)
                    continue

                message_content = bus.get_message(topic)

                # Update data sources that subscribe to this topic
                for datasource in datasource_subscribers:
                    datasource.content = message_content
                    yield create_event("datasource_updated", datasource=datasource.label,
                                       topic=topic, content=message_content)

                # Process agents that subscribe to this topic
                for agent in subscribers:
                    yield create_event("agent_triggered", agent=agent.title, topic=topic)
                    yield create_event("agent_processing", agent=agent.title)
                    yield create_event("agent_input", content=message_content)

                    # Execute agent
                    try:
                        result, entities, analyzed_text = await execute_agent(
                            agent, message_content, request.data_sources, request.user_question)
                        yield create_event("agent_output", content=result)

                        # Special handling for SQL agents
                        if is_sql_agent(agent.model) and analyzed_text:
                            # analyzed_text contains the SQL query that was executed
                            result_dict = json.loads(result) if not result.startswith("{\"error\"") else {"error": "Query failed"}
                            if "row_count" in result_dict:
                                yield create_event("sql_result", sql=analyzed_text,
                                                   rows=result_dict["row_count"])

                        # If agent wants to show result, send it to frontend
                        if agent.show_result:
                            yield create_event("show_result", agent=agent.title, content=result)

                        # If this is an NER agent with entities, also send formatted NER result
                        if entities and is_ner_model(agent.model) and analyzed_text:
                            formatted_text = format_ner_result(analyzed_text, entities)
                            yield create_event("ner_result", agent=agent.title,
                                               formatted_text=formatted_text)

                        # Publish result to agent's publish topic (if specified)
                        if agent.publish_topic:
                            bus.publish(agent.publish_topic, result)
                            yield create_event("message_published", topic=agent.publish_topic,
                                               content=result)

                        yield create_event("agent_completed", agent=agent.title)
                    except Exception as e:
                        yield create_event("error", message=f"Agent {agent.title} failed: {str(e)}")

                processed_topics.add(topic)

        yield create_event("execution_complete")
    except Exception as e:
        yield create_event("error", message=str(e))


# ---------------------------------------------------------------------------
# FastAPI endpoints
# ---------------------------------------------------------------------------

@app.on_event("startup")
def startup_event():
    """Open the database and register SQL extensions before serving."""
    global con, db_ready
    print("Initializing database...")
    con = init_db()
    # Register SQL extensions
    register_extensions(con)
    db_ready = True
    print("Database ready.")


@app.get("/")
async def root():
    """Serve the main web interface"""
    index_file = static_dir / "index.html"
    if index_file.exists():
        return FileResponse(index_file)
    return {"message": "Place index.html in the static/ directory"}


@app.post("/execute")
async def execute(request: ExecutionRequest):
    """Execute the pub/sub agent system with streaming logs"""
    return StreamingResponse(
        execute_pipeline(request),
        media_type="text/event-stream"
    )


@app.get("/health")
async def health():
    """Health check endpoint"""
    return {"status": "ok"}


@app.get("/status")
def status():
    """Report whether the database finished initializing."""
    return {"ready": db_ready}


if __name__ == "__main__":
    import uvicorn
    port = int(os.environ.get("PORT", 7860))
    uvicorn.run(app, host="0.0.0.0", port=port)