| from fastapi import FastAPI |
| from fastapi.middleware.cors import CORSMiddleware |
| from fastapi.responses import StreamingResponse, FileResponse |
| from fastapi.staticfiles import StaticFiles |
| from pydantic import BaseModel |
| from langchain_community.llms import Ollama |
| from langchain.prompts import PromptTemplate |
| from typing import List, Dict, Optional, AsyncGenerator |
| import json |
| import asyncio |
| from pathlib import Path |
| import os |
| import re |
|
|
| from database import init_db |
| from db_extensions import register_extensions |
|
|
| |
# transformers is an optional dependency: only the NER agent path needs it,
# everything else (LLM and SQL agents) works without it.
try:
    from transformers import pipeline
    TRANSFORMERS_AVAILABLE = True
except ImportError:
    TRANSFORMERS_AVAILABLE = False
    print("Warning: transformers not available, NER models will not work")
|
|
app = FastAPI(title="Pub/Sub Multi-Agent System")


# Shared database state, populated by the startup handler below.
# db_ready gates SQL execution; con is the shared DB connection.
db_ready = False
con = None


# NOTE(review): allow_origins=["*"] combined with allow_credentials=True is
# overly permissive for production — confirm this is intentional.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)


# Serve the web UI assets from ./static (created on first run if missing).
static_dir = Path(__file__).parent / "static"
static_dir.mkdir(exist_ok=True)
app.mount("/static", StaticFiles(directory=static_dir), name="static")
|
|
| |
class DataSource(BaseModel):
    """A named piece of text that can be spliced into agent prompts.

    content is substituted for the {label} placeholder in prompts, and may be
    overwritten at runtime when subscribe_topic receives a message.
    """
    label: str  # placeholder name: {label} in agent prompts
    content: str  # text injected into prompts
    subscribe_topic: Optional[str] = None  # topic whose messages replace content
|
|
class Agent(BaseModel):
    """Configuration for one pub/sub agent.

    model selects the backend: "SQL" (any case) executes the rendered prompt
    as SQL, a known NER checkpoint runs HF NER, anything else is an Ollama
    model name.
    """
    title: str  # display name used in streamed events
    prompt: str  # may contain {input}, {question}, {<label>} placeholders
    model: str  # backend selector (see class docstring)
    subscribe_topic: str  # topic that triggers this agent
    publish_topic: Optional[str] = None  # where the agent's output is published
    show_result: bool = False  # emit a show_result event for the UI
|
|
class ExecutionRequest(BaseModel):
    """Payload for POST /execute: the data, the question, and the agent graph."""
    data_sources: List[DataSource]
    user_question: str = ""  # published on the START topic to kick things off
    agents: List[Agent]
|
|
class Query(BaseModel):
    """A raw SQL query payload.

    NOTE(review): not referenced by any endpoint visible in this file —
    possibly dead code or consumed elsewhere; confirm before removing.
    """
    sql: str
| |
| |
class MessageBus:
    """In-memory, single-message-per-topic bus with case-insensitive topics.

    Agents and data sources subscribe to topics; publishing stores the latest
    message for a topic (replacing any previous one).
    """

    def __init__(self):
        self.subscribers: Dict[str, List[Agent]] = {}
        self.datasource_subscribers: Dict[str, List[DataSource]] = {}
        self.messages: Dict[str, str] = {}

    def reset(self):
        """Clear all subscriptions and messages for a fresh execution."""
        self.subscribers = {}
        self.datasource_subscribers = {}
        self.messages = {}

    def _normalize_topic(self, topic: str) -> str:
        """Lower-case and trim a topic so matching ignores case/whitespace."""
        return topic.lower().strip()

    def subscribe(self, topic: str, agent: Agent):
        """Register *agent* as a listener on *topic* (case insensitive)."""
        key = self._normalize_topic(topic)
        self.subscribers.setdefault(key, []).append(agent)

    def subscribe_datasource(self, topic: str, datasource: DataSource):
        """Register *datasource* as a listener on *topic* (case insensitive)."""
        key = self._normalize_topic(topic)
        self.datasource_subscribers.setdefault(key, []).append(datasource)

    def publish(self, topic: str, content: str):
        """Store *content* as the latest message on *topic* (case insensitive)."""
        self.messages[self._normalize_topic(topic)] = content

    def get_message(self, topic: str) -> Optional[str]:
        """Return the latest message on *topic*, or None if nothing was published."""
        return self.messages.get(self._normalize_topic(topic))

    def get_subscribers(self, topic: str) -> List[Agent]:
        """Return agents listening on *topic* (empty list if none)."""
        return self.subscribers.get(self._normalize_topic(topic), [])

    def get_datasource_subscribers(self, topic: str) -> List[DataSource]:
        """Return data sources listening on *topic* (empty list if none)."""
        return self.datasource_subscribers.get(self._normalize_topic(topic), [])
|
|
| |
def create_event(event_type: str, **kwargs):
    """Serialize one server-sent event frame: 'data: <json>\\n\\n'."""
    payload = {"type": event_type}
    payload.update(kwargs)
    return "data: " + json.dumps(payload) + "\n\n"
|
|
| |
def get_llm(model_name: str):
    """Build an Ollama client for *model_name*, pinned to a low temperature."""
    client = Ollama(model=model_name, temperature=0.1)
    return client
|
|
| |
# Process-wide cache of loaded HF NER pipelines, keyed by model name, so each
# model is downloaded/instantiated at most once per process.
_ner_pipelines = {}


def get_ner_pipeline(model_name: str):
    """Get or lazily create the cached NER pipeline for *model_name*.

    Raises:
        RuntimeError: if the optional transformers dependency is missing.
    """
    if not TRANSFORMERS_AVAILABLE:
        raise RuntimeError("transformers package not available")

    if model_name not in _ner_pipelines:
        print(f"Loading NER model: {model_name}")
        # aggregation_strategy="simple" merges word-piece tokens into whole
        # entity spans with start/end character offsets.
        _ner_pipelines[model_name] = pipeline(
            "ner",
            model=model_name,
            aggregation_strategy="simple"
        )
    return _ner_pipelines[model_name]
|
|
| |
def is_ner_model(model_name: str) -> bool:
    """Return True when *model_name* is one of the supported HF NER checkpoints."""
    return model_name in (
        "samrawal/bert-base-uncased_clinical-ner",
        "OpenMed/OpenMed-NER-AnatomyDetect-BioPatient-108M",
    )
|
|
| |
def is_sql_agent(model_name: str) -> bool:
    """Return True when the agent's "model" is the SQL pseudo-model (any case)."""
    return model_name.lower() == "sql"
|
|
| |
def format_ner_result(text: str, entities: List[Dict]) -> str:
    """Return *text* with every entity span rewritten as [span:TYPE].

    Splices are applied from the end of the string backwards so that earlier
    character offsets remain valid as the string grows.
    """
    if not entities:
        return text

    annotated = text
    for ent in sorted(entities, key=lambda e: e['start'], reverse=True):
        begin = ent['start']
        finish = ent['end']
        tag = f"[{text[begin:finish]}:{ent['entity_group']}]"
        annotated = annotated[:begin] + tag + annotated[finish:]

    return annotated
|
|
| |
def execute_sql_query(sql: str) -> tuple[str, Optional[Dict]]:
    """Execute *sql* on the shared connection.

    Returns:
        (json_output, formatted_result) where json_output is always a JSON
        string (an {"error": ...} object on failure) and formatted_result is
        the {"columns", "rows", "row_count"} dict, or None on failure.
    """
    if not db_ready:
        # Bug fix: this branch previously returned a bare dict, which broke
        # callers unpacking the declared (str, dict|None) tuple.
        return json.dumps({"error": "Database not ready"}), None

    try:
        result = con.execute(sql).fetchall()
        columns = [desc[0] for desc in con.description]

        formatted_result = {
            "columns": columns,
            "rows": result,
            "row_count": len(result)
        }
        json_output = json.dumps(formatted_result, indent=2)
        return json_output, formatted_result
    except Exception as e:
        error_msg = f"SQL execution failed: {str(e)}"
        return json.dumps({"error": error_msg}), None
|
|
| |
def process_ner(text: str, model_name: str) -> tuple[str, List[Dict]]:
    """Run NER over *text* with *model_name*.

    Returns (json_string, entity_dicts); on failure the JSON carries an
    "error" key and the entity list is empty.
    """
    try:
        raw_entities = get_ner_pipeline(model_name)(text)

        # Coerce pipeline output (which may contain numpy scalars) into
        # plain JSON-serializable types.
        formatted_entities = [
            {
                "text": str(item['word']),
                "entity_type": str(item['entity_group']),
                "start": int(item['start']),
                "end": int(item['end']),
                "score": float(item.get('score', 0.0)),
            }
            for item in raw_entities
        ]

        return json.dumps(formatted_entities, indent=2), formatted_entities

    except Exception as e:
        error_msg = f"NER processing failed: {str(e)}"
        return json.dumps({"error": error_msg}), []
|
|
| |
async def execute_agent(agent: Agent, input_content: str, data_sources: List[DataSource], user_question: str) -> tuple[str, Optional[List[Dict]], Optional[str]]:
    """Execute a single agent against *input_content*.

    Placeholders {input}, {question} and {<datasource label>} in the agent's
    prompt are substituted case-insensitively, then the rendered prompt is
    dispatched to one of three backends: the SQL pseudo-agent, a HF NER
    model, or an Ollama LLM.

    Returns:
        (result, entities, analyzed_text) — entities is populated only for
        NER models; analyzed_text is the SQL text for SQL agents and the
        analyzed input for NER models, otherwise None.
    """

    def replace_case_insensitive(text: str, placeholder: str, value: str) -> str:
        """Replace every occurrence of *placeholder* in *text*, ignoring case."""
        pattern = re.compile(re.escape(placeholder), re.IGNORECASE)
        return pattern.sub(value, text)

    prompt_text = agent.prompt if agent.prompt else ""

    prompt_text = replace_case_insensitive(prompt_text, "{input}", input_content)
    prompt_text = replace_case_insensitive(prompt_text, "{question}", user_question)

    # Every data source label is also a placeholder: {label} -> content.
    for ds in data_sources:
        placeholder = "{" + ds.label + "}"
        prompt_text = replace_case_insensitive(prompt_text, placeholder, ds.content)

    if is_sql_agent(agent.model):
        # The rendered prompt IS the SQL; fall back to the incoming message
        # when the prompt is blank.
        sql_query = prompt_text.strip() or input_content

        # Fix: the formatted dict was previously bound to an unused local;
        # only the JSON string is needed here.
        json_result, _ = execute_sql_query(sql_query)
        return json_result, None, sql_query

    elif is_ner_model(agent.model):
        # Keep the unstripped prompt as the analyzed text (offsets matter);
        # fall back to the incoming message when the prompt is blank.
        text_to_analyze = prompt_text
        if not text_to_analyze.strip():
            text_to_analyze = input_content

        json_result, entities = process_ner(text_to_analyze, agent.model)
        return json_result, entities, text_to_analyze
    else:
        llm = get_llm(agent.model)

        # Fix: Ollama's invoke() is synchronous; run it in a worker thread so
        # the event loop (and the SSE stream) stays responsive.
        result = await asyncio.to_thread(llm.invoke, prompt_text)

        return (result if isinstance(result, str) else str(result)), None, None
|
|
| |
async def execute_pipeline(request: ExecutionRequest) -> AsyncGenerator[str, None]:
    """Drive one pub/sub run, yielding SSE frames describing progress.

    Flow: subscribe agents and data sources to their topics, publish the
    user's question on START, then repeatedly deliver each unprocessed
    topic's message to its subscribers until no new topics appear or the
    iteration cap is hit. Every step is surfaced to the client as a
    create_event(...) frame.
    """
    try:
        bus = MessageBus()

        yield create_event("bus_init")

        # Belt-and-braces: the bus is freshly constructed, but reset anyway.
        bus.reset()

        # Wire up agent subscriptions.
        for agent in request.agents:
            if agent.subscribe_topic:
                bus.subscribe(agent.subscribe_topic, agent)
                yield create_event("agent_subscribed", agent=agent.title, topic=agent.subscribe_topic)

        # Data sources may also listen on a topic; their content is
        # overwritten by messages published there (see delivery loop below).
        for datasource in request.data_sources:
            if datasource.subscribe_topic:
                bus.subscribe_datasource(datasource.subscribe_topic, datasource)
                yield create_event("datasource_subscribed", datasource=datasource.label, topic=datasource.subscribe_topic)

        # Kick off the run on the START topic.
        start_message = request.user_question if request.user_question else "System initialized"
        bus.publish("START", start_message)
        yield create_event("message_published", topic="START", content=start_message)

        # Each topic is delivered at most once; the cap guards against
        # agents publishing in a cycle forever.
        processed_topics = set()
        max_iterations = 20
        iteration = 0

        while iteration < max_iterations:
            iteration += 1

            # Topics that received a message but have not been delivered yet.
            topics_to_process = [topic for topic in bus.messages.keys() if topic not in processed_topics]

            if not topics_to_process:
                break

            for topic in topics_to_process:
                subscribers = bus.get_subscribers(topic)
                datasource_subscribers = bus.get_datasource_subscribers(topic)

                if not subscribers and not datasource_subscribers:
                    yield create_event("no_subscribers", topic=topic)
                    processed_topics.add(topic)
                    continue

                message_content = bus.get_message(topic)

                # Subscribed data sources have their content replaced by the
                # published message (mutates the request's DataSource objects).
                for datasource in datasource_subscribers:
                    datasource.content = message_content
                    yield create_event("datasource_updated",
                                     datasource=datasource.label,
                                     topic=topic,
                                     content=message_content)

                # Fan the message out to every subscribed agent.
                for agent in subscribers:
                    yield create_event("agent_triggered", agent=agent.title, topic=topic)
                    yield create_event("agent_processing", agent=agent.title)
                    yield create_event("agent_input", content=message_content)

                    try:
                        result, entities, analyzed_text = await execute_agent(agent, message_content, request.data_sources, request.user_question)
                        yield create_event("agent_output", content=result)

                        # SQL agents: surface the row count to the UI.
                        # NOTE(review): the startswith check assumes compact
                        # error JSON; indent=2 output would not match, but
                        # json.loads handles either shape anyway.
                        if is_sql_agent(agent.model) and analyzed_text:
                            result_dict = json.loads(result) if not result.startswith("{\"error\"") else {"error": "Query failed"}
                            if "row_count" in result_dict:
                                yield create_event("sql_result",
                                                 sql=analyzed_text,
                                                 rows=result_dict["row_count"])

                        if agent.show_result:
                            yield create_event("show_result", agent=agent.title, content=result)

                        # NER agents: also emit the human-readable annotation.
                        if entities and is_ner_model(agent.model) and analyzed_text:
                            formatted_text = format_ner_result(analyzed_text, entities)
                            yield create_event("ner_result", agent=agent.title, formatted_text=formatted_text)

                        # Chain: the agent's output becomes a new topic message.
                        if agent.publish_topic:
                            bus.publish(agent.publish_topic, result)
                            yield create_event("message_published", topic=agent.publish_topic, content=result)

                        yield create_event("agent_completed", agent=agent.title)

                    except Exception as e:
                        # One failing agent does not abort the run.
                        yield create_event("error", message=f"Agent {agent.title} failed: {str(e)}")

                processed_topics.add(topic)

        yield create_event("execution_complete")

    except Exception as e:
        yield create_event("error", message=str(e))
|
|
@app.on_event("startup")
def startup_event():
    """Initialize the shared database connection once at application startup.

    NOTE(review): @app.on_event is deprecated in recent FastAPI versions in
    favor of lifespan handlers — consider migrating.
    """
    global con, db_ready

    print("Initializing database...")

    con = init_db()

    # Install project-specific extensions on the shared connection.
    register_extensions(con)

    # execute_sql_query and /status gate on this flag.
    db_ready = True

    print("Database ready.")
|
|
@app.get("/")
async def root():
    """Serve the web UI entry point, or a setup hint when it is missing."""
    index_path = static_dir / "index.html"
    if not index_path.exists():
        return {"message": "Place index.html in the static/ directory"}
    return FileResponse(index_path)
|
|
@app.post("/execute")
async def execute(request: ExecutionRequest):
    """Run the pub/sub agent pipeline, streaming progress as server-sent events."""
    event_stream = execute_pipeline(request)
    return StreamingResponse(event_stream, media_type="text/event-stream")
|
|
@app.get("/health")
async def health():
    """Liveness probe: always reports ok while the process is up."""
    return dict(status="ok")
|
|
@app.get("/status")
def status():
    """Readiness probe: reports whether database initialization has finished."""
    return dict(ready=db_ready)
|
|
if __name__ == "__main__":
    import uvicorn
    # Honor a PORT override from the environment; default to 7860.
    port = int(os.environ.get("PORT", 7860))
    uvicorn.run(app, host="0.0.0.0", port=port)
|
|