| | from fastapi import FastAPI |
| | from fastapi.middleware.cors import CORSMiddleware |
| | from fastapi.responses import StreamingResponse, FileResponse |
| | from fastapi.staticfiles import StaticFiles |
| | from pydantic import BaseModel |
| | from langchain_community.llms import Ollama |
| | from langchain.prompts import PromptTemplate |
| | from typing import List, Dict, Optional, AsyncGenerator |
| | import json |
| | import asyncio |
| | from pathlib import Path |
| | import os |
| | import re |
| |
|
| | from database import init_db |
| | from db_extensions import register_extensions |
| |
|
| | |
# transformers is an optional dependency: when it is missing the server still
# starts, but any agent whose model is an NER checkpoint will fail at runtime
# (get_ner_pipeline raises RuntimeError).
try:
    from transformers import pipeline
    TRANSFORMERS_AVAILABLE = True
except ImportError:
    TRANSFORMERS_AVAILABLE = False
    print("Warning: transformers not available, NER models will not work")
| |
|
app = FastAPI(title="Pub/Sub Multi-Agent System")

# Database handle and readiness flag; both are populated by the startup event
# below and read by execute_sql_query and the /status endpoint.
db_ready = False
con = None

# NOTE(review): allow_origins=["*"] together with allow_credentials=True is
# maximally permissive CORS — confirm this service only runs behind a trusted
# frontend before exposing it publicly.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Serve the bundled single-page UI from ./static (directory created if absent
# so the mount never fails on a fresh checkout).
static_dir = Path(__file__).parent / "static"
static_dir.mkdir(exist_ok=True)
app.mount("/static", StaticFiles(directory=static_dir), name="static")
| |
|
| | |
class DataSource(BaseModel):
    """A named text blob that agent prompts can reference as ``{label}``.

    When ``subscribe_topic`` is set, the pipeline overwrites ``content`` with
    any message published on that topic (see execute_pipeline).
    """
    label: str
    content: str
    subscribe_topic: Optional[str] = None  # topic whose messages replace `content`
| |
|
class Agent(BaseModel):
    """One pipeline step.

    Listens on ``subscribe_topic``, renders ``prompt`` (placeholders:
    ``{input}``, ``{question}``, ``{<datasource label>}``), runs it through
    ``model`` — an Ollama model name, a supported NER checkpoint, or the
    literal "SQL" — and optionally re-publishes the result on
    ``publish_topic``.
    """
    title: str
    prompt: str
    model: str
    subscribe_topic: str
    publish_topic: Optional[str] = None  # omit to make this a terminal agent
    show_result: bool = False  # emit a show_result event for the UI
| |
|
class ExecutionRequest(BaseModel):
    """Payload for POST /execute: data sources, agents, and the user question."""
    data_sources: List[DataSource]
    user_question: str = ""
    agents: List[Agent]
| |
|
class Query(BaseModel):
    """Raw SQL payload.

    NOTE(review): no visible endpoint uses this model — confirm it is still
    needed or remove it.
    """
    sql: str
| | |
| | |
class MessageBus:
    """Single-request, in-memory pub/sub hub.

    Topic names are matched case-insensitively (normalized via
    ``_normalize_topic``) and each topic stores only its most recent message.
    """

    def __init__(self):
        self.subscribers: Dict[str, List[Agent]] = {}
        self.datasource_subscribers: Dict[str, List[DataSource]] = {}
        self.messages: Dict[str, str] = {}

    def reset(self):
        """Discard all subscriptions and messages ahead of a new execution."""
        self.subscribers = {}
        self.datasource_subscribers = {}
        self.messages = {}

    def _normalize_topic(self, topic: str) -> str:
        """Canonical topic key: lowercased, then stripped of surrounding whitespace."""
        return topic.lower().strip()

    def subscribe(self, topic: str, agent: Agent):
        """Register *agent* as a listener on *topic* (case-insensitive)."""
        key = self._normalize_topic(topic)
        self.subscribers.setdefault(key, []).append(agent)

    def subscribe_datasource(self, topic: str, datasource: DataSource):
        """Register *datasource* so messages on *topic* overwrite its content."""
        key = self._normalize_topic(topic)
        self.datasource_subscribers.setdefault(key, []).append(datasource)

    def publish(self, topic: str, content: str):
        """Set *content* as the current message on *topic*, replacing any prior one."""
        self.messages[self._normalize_topic(topic)] = content

    def get_message(self, topic: str) -> Optional[str]:
        """Return the latest message on *topic*, or None if nothing was published."""
        return self.messages.get(self._normalize_topic(topic))

    def get_subscribers(self, topic: str) -> List[Agent]:
        """Return the agents listening on *topic* (empty list when none)."""
        return self.subscribers.get(self._normalize_topic(topic), [])

    def get_datasource_subscribers(self, topic: str) -> List[DataSource]:
        """Return the data sources listening on *topic* (empty list when none)."""
        return self.datasource_subscribers.get(self._normalize_topic(topic), [])
| |
|
| | |
def create_event(event_type: str, **kwargs):
    """Serialize one Server-Sent-Events frame carrying a mandatory "type" field."""
    payload = json.dumps({"type": event_type, **kwargs})
    return f"data: {payload}\n\n"
| |
|
| | |
def get_llm(model_name: str):
    """Construct a near-deterministic (temperature 0.1) Ollama client for *model_name*."""
    client = Ollama(model=model_name, temperature=0.1)
    return client
| |
|
| | |
# Process-wide cache of loaded HF NER pipelines keyed by model name
# (loading a checkpoint is expensive, so each model is loaded at most once).
_ner_pipelines = {}
| |
|
def get_ner_pipeline(model_name: str):
    """Return a cached HF token-classification pipeline, loading it on first use.

    Raises:
        RuntimeError: when the optional transformers package is not installed.
    """
    if not TRANSFORMERS_AVAILABLE:
        raise RuntimeError("transformers package not available")

    cached = _ner_pipelines.get(model_name)
    if cached is None:
        print(f"Loading NER model: {model_name}")
        cached = pipeline(
            "ner",
            model=model_name,
            aggregation_strategy="simple",
        )
        _ner_pipelines[model_name] = cached
    return cached
| |
|
| | |
def is_ner_model(model_name: str) -> bool:
    """Return True when *model_name* is one of the supported HF NER checkpoints."""
    supported = {
        "samrawal/bert-base-uncased_clinical-ner",
        "OpenMed/OpenMed-NER-AnatomyDetect-BioPatient-108M",
    }
    return model_name in supported
| |
|
| | |
def is_sql_agent(model_name: str) -> bool:
    """True when this agent's "model" selects the built-in SQL executor instead of an LLM."""
    normalized = model_name.upper()
    return normalized == "SQL"
| |
|
| | |
def format_ner_result(text: str, entities: List[Dict]) -> str:
    """Inline-annotate *text*, wrapping each detected span as ``[span:TYPE]``.

    Entity dicts must carry ``start``, ``end`` and ``entity_group`` keys
    (character offsets into the original *text*).
    """
    if not entities:
        return text

    annotated = text
    # Splice right-to-left so earlier character offsets remain valid after
    # each (longer) replacement is inserted.
    for ent in sorted(entities, key=lambda e: e['start'], reverse=True):
        begin = ent['start']
        finish = ent['end']
        tag = f"[{text[begin:finish]}:{ent['entity_group']}]"
        annotated = annotated[:begin] + tag + annotated[finish:]

    return annotated
| |
|
| | |
def execute_sql_query(sql: str) -> tuple[str, Optional[Dict]]:
    """Execute *sql* on the global connection and return (json_output, result).

    Returns:
        A ``(json_output, formatted_result)`` tuple. On success
        ``formatted_result`` is ``{"columns", "rows", "row_count"}`` and
        ``json_output`` is its pretty-printed JSON. On any failure
        ``json_output`` encodes ``{"error": ...}`` and ``formatted_result``
        is None.
    """
    if not db_ready:
        # BUG FIX: this path previously returned a bare dict, violating the
        # (str, Optional[Dict]) contract and crashing callers that unpack
        # the tuple (see execute_agent).
        return json.dumps({"error": "Database not ready"}), None

    try:
        rows = con.execute(sql).fetchall()
        # DB-API style: column names come from the cursor description.
        columns = [desc[0] for desc in con.description]

        formatted_result = {
            "columns": columns,
            "rows": rows,
            "row_count": len(rows)
        }
        json_output = json.dumps(formatted_result, indent=2)
        return json_output, formatted_result
    except Exception as e:
        # Broad catch is deliberate: driver-specific exception types are
        # unknown here and the error is reported to the client as data.
        error_msg = f"SQL execution failed: {str(e)}"
        return json.dumps({"error": error_msg}), None
| |
|
| | |
def process_ner(text: str, model_name: str) -> tuple[str, List[Dict]]:
    """Run *text* through the NER model and return (JSON string, entity dicts).

    Each entity dict carries text/entity_type/start/end/score with plain
    Python types (the pipeline yields numpy scalars, which JSON rejects).
    On failure the JSON encodes {"error": ...} and the entity list is empty.
    """
    try:
        raw_entities = get_ner_pipeline(model_name)(text)

        formatted_entities = [
            {
                "text": str(ent['word']),
                "entity_type": str(ent['entity_group']),
                "start": int(ent['start']),
                "end": int(ent['end']),
                "score": float(ent.get('score', 0.0)),
            }
            for ent in raw_entities
        ]

        return json.dumps(formatted_entities, indent=2), formatted_entities

    except Exception as exc:
        error_msg = f"NER processing failed: {str(exc)}"
        return json.dumps({"error": error_msg}), []
| |
|
| | |
async def execute_agent(agent: Agent, input_content: str, data_sources: List[DataSource], user_question: str) -> tuple[str, Optional[List[Dict]], Optional[str]]:
    """Run one agent against *input_content*.

    Returns:
        ``(result, entities, analyzed_text)`` — ``entities`` is populated only
        for NER models; ``analyzed_text`` holds the executed SQL for SQL
        agents or the analyzed text for NER agents, else None.
    """

    def _substitute(template: str, placeholder: str, value: str) -> str:
        # Case-insensitive literal replacement of a {placeholder} token.
        return re.sub(re.escape(placeholder), value, template, flags=re.IGNORECASE)

    # Render the prompt: {input}, {question}, then one {label} per data source.
    rendered = _substitute(agent.prompt or "", "{input}", input_content)
    rendered = _substitute(rendered, "{question}", user_question)
    for source in data_sources:
        rendered = _substitute(rendered, "{" + source.label + "}", source.content)

    if is_sql_agent(agent.model):
        # The rendered prompt IS the SQL; fall back to the bus message if blank.
        statement = rendered.strip() or input_content
        json_result, _ = execute_sql_query(statement)
        return json_result, None, statement

    if is_ner_model(agent.model):
        # Analyze the rendered prompt; fall back to the bus message if blank.
        target = rendered if rendered.strip() else input_content
        json_result, entities = process_ner(target, agent.model)
        return json_result, entities, target

    # Default path: ordinary LLM call via Ollama.
    answer = get_llm(agent.model).invoke(rendered)
    return (answer if isinstance(answer, str) else str(answer)), None, None
| |
|
| | |
async def execute_pipeline(request: ExecutionRequest) -> AsyncGenerator[str, None]:
    """Drive one pub/sub execution, yielding SSE "data:" frames as it goes.

    Flow: build a fresh MessageBus, register agent and data-source
    subscriptions, publish the user question on "START", then repeatedly
    deliver every not-yet-processed topic until no new messages appear or
    the iteration cap is hit. All failures are reported as "error" events
    rather than raised.
    """
    try:
        bus = MessageBus()

        yield create_event("bus_init")

        # Bus is request-local, so this reset is redundant but harmless.
        bus.reset()

        # Register agent subscriptions.
        for agent in request.agents:
            if agent.subscribe_topic:
                bus.subscribe(agent.subscribe_topic, agent)
                yield create_event("agent_subscribed", agent=agent.title, topic=agent.subscribe_topic)

        # Register data-source subscriptions (their content gets overwritten).
        for datasource in request.data_sources:
            if datasource.subscribe_topic:
                bus.subscribe_datasource(datasource.subscribe_topic, datasource)
                yield create_event("datasource_subscribed", datasource=datasource.label, topic=datasource.subscribe_topic)

        # Kick off the pipeline: everything listening on START fires first.
        start_message = request.user_question if request.user_question else "System initialized"
        bus.publish("START", start_message)
        yield create_event("message_published", topic="START", content=start_message)

        # Each topic is delivered at most once; agents may publish new topics,
        # which are picked up on later passes. max_iterations caps publish
        # cycles (e.g. A -> B -> A).
        processed_topics = set()
        max_iterations = 20
        iteration = 0

        while iteration < max_iterations:
            iteration += 1

            # Topics holding a message that has not been delivered yet.
            topics_to_process = [topic for topic in bus.messages.keys() if topic not in processed_topics]

            if not topics_to_process:
                break

            for topic in topics_to_process:
                subscribers = bus.get_subscribers(topic)
                datasource_subscribers = bus.get_datasource_subscribers(topic)

                if not subscribers and not datasource_subscribers:
                    yield create_event("no_subscribers", topic=topic)
                    processed_topics.add(topic)
                    continue

                message_content = bus.get_message(topic)

                # Data sources simply absorb the message as their new content.
                for datasource in datasource_subscribers:
                    datasource.content = message_content
                    yield create_event("datasource_updated",
                                     datasource=datasource.label,
                                     topic=topic,
                                     content=message_content)

                # Agents run on the message and may re-publish their result.
                for agent in subscribers:
                    yield create_event("agent_triggered", agent=agent.title, topic=topic)
                    yield create_event("agent_processing", agent=agent.title)
                    yield create_event("agent_input", content=message_content)

                    # One failing agent must not kill the whole pipeline.
                    try:
                        result, entities, analyzed_text = await execute_agent(agent, message_content, request.data_sources, request.user_question)
                        yield create_event("agent_output", content=result)

                        # Extra telemetry for SQL agents: row count of the query.
                        if is_sql_agent(agent.model) and analyzed_text:
                            result_dict = json.loads(result) if not result.startswith("{\"error\"") else {"error": "Query failed"}
                            if "row_count" in result_dict:
                                yield create_event("sql_result",
                                                 sql=analyzed_text,
                                                 rows=result_dict["row_count"])

                        if agent.show_result:
                            yield create_event("show_result", agent=agent.title, content=result)

                        # Extra telemetry for NER agents: inline-annotated text.
                        if entities and is_ner_model(agent.model) and analyzed_text:
                            formatted_text = format_ner_result(analyzed_text, entities)
                            yield create_event("ner_result", agent=agent.title, formatted_text=formatted_text)

                        if agent.publish_topic:
                            bus.publish(agent.publish_topic, result)
                            yield create_event("message_published", topic=agent.publish_topic, content=result)

                        yield create_event("agent_completed", agent=agent.title)

                    except Exception as e:
                        yield create_event("error", message=f"Agent {agent.title} failed: {str(e)}")

                processed_topics.add(topic)

        yield create_event("execution_complete")

    except Exception as e:
        yield create_event("error", message=str(e))
| |
|
@app.on_event("startup")
def startup_event():
    """Open the database connection once the app starts.

    Populates the module-level ``con`` handle and flips ``db_ready``, which
    gates execute_sql_query and is reported by /status.
    """
    global con, db_ready

    print("Initializing database...")

    con = init_db()

    # Install project-specific SQL extensions on the fresh connection.
    register_extensions(con)

    db_ready = True

    print("Database ready.")
| |
|
@app.get("/")
async def root():
    """Serve static/index.html when present, otherwise a setup hint."""
    page = static_dir / "index.html"
    if not page.exists():
        return {"message": "Place index.html in the static/ directory"}
    return FileResponse(page)
| |
|
@app.post("/execute")
async def execute(request: ExecutionRequest):
    """Run the pub/sub agent pipeline, streaming progress as Server-Sent Events."""
    event_stream = execute_pipeline(request)
    return StreamingResponse(event_stream, media_type="text/event-stream")
| |
|
@app.get("/health")
async def health():
    """Liveness probe: always reports ok while the process is up."""
    return {"status": "ok"}
| |
|
@app.get("/status")
def status():
    """Readiness probe: True once the startup event finished initializing the DB."""
    return {"ready": db_ready}
| |
|
if __name__ == "__main__":
    import uvicorn

    # PORT env var wins; 7860 is the conventional Hugging Face Spaces default.
    listen_port = int(os.environ.get("PORT", 7860))
    uvicorn.run(app, host="0.0.0.0", port=listen_port)
| |
|