# fix (dbvec): imports and data — commit 10fc8ae
# (scraped page header converted to a comment so the file parses)
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse, FileResponse
from fastapi.staticfiles import StaticFiles
from pydantic import BaseModel
from langchain_community.llms import Ollama
from langchain.prompts import PromptTemplate
from typing import List, Dict, Optional, AsyncGenerator
import json
import asyncio
from pathlib import Path
import os
import re
from database import init_db
from db_extensions import register_extensions
# Import transformers for NER
# Optional dependency: NER agents need the HuggingFace transformers package.
# The rest of the app (LLM and SQL agents) still works without it.
try:
    from transformers import pipeline
    TRANSFORMERS_AVAILABLE = True
except ImportError:
    TRANSFORMERS_AVAILABLE = False
    print("Warning: transformers not available, NER models will not work")
app = FastAPI(title="Pub/Sub Multi-Agent System")
# Module-level globals populated by the startup hook:
# con is the database connection from init_db(); db_ready gates SQL execution.
db_ready = False
con = None
# Enable CORS
# NOTE(review): wildcard origins combined with allow_credentials=True is wide
# open — acceptable for a local demo, tighten before a public deployment.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
# Mount static files
# The static/ directory lives next to this file and is created if missing.
static_dir = Path(__file__).parent / "static"
static_dir.mkdir(exist_ok=True)
app.mount("/static", StaticFiles(directory=static_dir), name="static")
# Models
class DataSource(BaseModel):
    """A labeled piece of content that can be injected into agent prompts."""
    label: str  # placeholder name: "{label}" in a prompt is replaced by content
    content: str  # text substituted into prompts (may be overwritten via the bus)
    subscribe_topic: Optional[str] = None  # topic whose messages replace content
class Agent(BaseModel):
    """Configuration for one agent in the pub/sub pipeline."""
    title: str  # display name used in streamed log events
    prompt: str  # template; supports {input}, {question} and {<datasource label>}
    model: str  # Ollama model name, a known NER checkpoint, or "SQL"
    subscribe_topic: str  # topic whose messages trigger this agent
    publish_topic: Optional[str] = None  # topic the agent's result is published to
    show_result: bool = False  # when True, a show_result event is sent to the frontend
class ExecutionRequest(BaseModel):
    """Payload for /execute: data sources, the user question, and the agents."""
    data_sources: List[DataSource]
    user_question: str = ""  # published on the START topic to kick off the run
    agents: List[Agent]
class Query(BaseModel):
    """A raw SQL query string.

    NOTE(review): not referenced anywhere in this file — presumably used by
    another endpoint or module; confirm before removing.
    """
    sql: str
# Pub/Sub Bus
class MessageBus:
    """In-memory publish/subscribe bus with case-insensitive topic names.

    Tracks, per normalized topic: the agents subscribed to it, the data
    sources subscribed to it, and the single most recently published
    message (one retained message per topic).
    """

    def __init__(self):
        self.subscribers: Dict[str, List[Agent]] = {}
        self.datasource_subscribers: Dict[str, List[DataSource]] = {}
        self.messages: Dict[str, str] = {}

    def reset(self):
        """Drop all subscriptions and retained messages for a fresh run."""
        self.subscribers = {}
        self.datasource_subscribers = {}
        self.messages = {}

    def _normalize_topic(self, topic: str) -> str:
        """Lowercase and strip a topic so matching is case-insensitive."""
        return topic.lower().strip()

    def subscribe(self, topic: str, agent: Agent):
        """Register *agent* as a subscriber of *topic* (case insensitive)."""
        key = self._normalize_topic(topic)
        self.subscribers.setdefault(key, []).append(agent)

    def subscribe_datasource(self, topic: str, datasource: DataSource):
        """Register *datasource* as a subscriber of *topic* (case insensitive)."""
        key = self._normalize_topic(topic)
        self.datasource_subscribers.setdefault(key, []).append(datasource)

    def publish(self, topic: str, content: str):
        """Retain *content* as the current message for *topic* (case insensitive)."""
        self.messages[self._normalize_topic(topic)] = content

    def get_message(self, topic: str) -> Optional[str]:
        """Return the retained message for *topic*, or None (case insensitive)."""
        return self.messages.get(self._normalize_topic(topic))

    def get_subscribers(self, topic: str) -> List[Agent]:
        """Return the agents subscribed to *topic* (case insensitive)."""
        return self.subscribers.get(self._normalize_topic(topic), [])

    def get_datasource_subscribers(self, topic: str) -> List[DataSource]:
        """Return the data sources subscribed to *topic* (case insensitive)."""
        return self.datasource_subscribers.get(self._normalize_topic(topic), [])
# Stream event helper
def create_event(event_type: str, **kwargs):
    """Serialize an event as a Server-Sent Events `data:` frame.

    The "type" key comes first; any kwarg named "type" overrides it,
    matching dict-literal merge semantics.
    """
    event = {"type": event_type}
    event.update(kwargs)
    return "data: " + json.dumps(event) + "\n\n"
# Get LLM instance
def get_llm(model_name: str):
    """Build an Ollama LLM client for *model_name* (low temperature = 0.1)."""
    return Ollama(model=model_name, temperature=0.1)
# NER pipeline cache
# Loaded HuggingFace pipelines keyed by model name; models load once per process.
_ner_pipelines = {}

def get_ner_pipeline(model_name: str):
    """Return a cached HuggingFace NER pipeline, loading it on first use.

    Raises:
        RuntimeError: if the transformers package is not installed.
    """
    if not TRANSFORMERS_AVAILABLE:
        raise RuntimeError("transformers package not available")
    cached = _ner_pipelines.get(model_name)
    if cached is None:
        print(f"Loading NER model: {model_name}")
        cached = pipeline(
            "ner",
            model=model_name,
            aggregation_strategy="simple"
        )
        _ner_pipelines[model_name] = cached
    return cached
# Check if model is NER model
def is_ner_model(model_name: str) -> bool:
    """Return True when *model_name* is one of the known NER checkpoints."""
    return model_name in (
        "samrawal/bert-base-uncased_clinical-ner",
        "OpenMed/OpenMed-NER-AnatomyDetect-BioPatient-108M",
    )
# Check if model is SQL agent
def is_sql_agent(model_name: str) -> bool:
    """Return True when the agent's model field names the built-in SQL agent."""
    upper_name = model_name.upper()
    return upper_name == "SQL"
# Format NER output for display
def format_ner_result(text: str, entities: List[Dict]) -> str:
    """Annotate *text* by rewriting each entity span as ``[span:TYPE]``.

    Spans are applied from the end of the string backwards so that earlier
    character offsets remain valid as replacements change the length.
    """
    if not entities:
        return text
    annotated = text
    for ent in sorted(entities, key=lambda e: e['start'], reverse=True):
        span_start = ent['start']
        span_end = ent['end']
        label = f"[{text[span_start:span_end]}:{ent['entity_group']}]"
        annotated = annotated[:span_start] + label + annotated[span_end:]
    return annotated
# Execute SQL query
def execute_sql_query(sql: str) -> tuple[str, Optional[Dict]]:
    """Execute *sql* against the global connection ``con``.

    Returns:
        (json_output, result_dict) — result_dict contains ``columns``,
        ``rows`` and ``row_count`` on success; on any failure (including
        the database not being ready) result_dict is None and json_output
        encodes an ``error`` message.
    """
    if not db_ready:
        # Bug fix: previously returned a bare dict here, which broke every
        # caller that unpacks the documented (str, Optional[Dict]) tuple.
        return json.dumps({"error": "Database not ready"}), None
    try:
        result = con.execute(sql).fetchall()
        columns = [desc[0] for desc in con.description]
        formatted_result = {
            "columns": columns,
            "rows": result,
            "row_count": len(result)
        }
        json_output = json.dumps(formatted_result, indent=2)
        return json_output, formatted_result
    except Exception as e:
        error_msg = f"SQL execution failed: {str(e)}"
        return json.dumps({"error": error_msg}), None
# Process NER with transformers pipeline
def process_ner(text: str, model_name: str) -> tuple[str, List[Dict]]:
    """Run NER over *text* and return (json_string, formatted entities).

    Pipeline scalar fields are coerced to plain Python types (numpy
    float32 scores become builtin floats) so the result serializes with
    json.dumps. Any failure yields an error JSON string and an empty list.
    """
    try:
        ner = get_ner_pipeline(model_name)
        raw_entities = ner(text)
        formatted_entities = [
            {
                "text": str(ent['word']),
                "entity_type": str(ent['entity_group']),
                "start": int(ent['start']),
                "end": int(ent['end']),
                # Convert numpy float32 to Python float
                "score": float(ent.get('score', 0.0)),
            }
            for ent in raw_entities
        ]
        return json.dumps(formatted_entities, indent=2), formatted_entities
    except Exception as e:
        error_msg = f"NER processing failed: {str(e)}"
        return json.dumps({"error": error_msg}), []
# Execute agent
async def execute_agent(agent: Agent, input_content: str, data_sources: List[DataSource], user_question: str) -> tuple[str, Optional[List[Dict]], Optional[str]]:
    """Execute a single agent with the given input.

    Returns (result, entities, analyzed_text): ``entities`` is populated
    only for NER models; ``analyzed_text`` is the SQL that ran (SQL agent)
    or the text analyzed (NER model), otherwise None.
    """
    # Case-insensitive replacement helper
    def replace_case_insensitive(text: str, placeholder: str, value: str) -> str:
        """Replace placeholder in text, case insensitive.

        Bug fix: the replacement is supplied via a callable so that
        backslashes and group references in *value* (e.g. "\\1", Windows
        paths, JSON escapes) are inserted literally; passing the raw string
        to re.sub interpreted them and could raise "bad escape" errors.
        """
        pattern = re.compile(re.escape(placeholder), re.IGNORECASE)
        return pattern.sub(lambda _m: value, text)
    # Start with the agent's prompt template
    prompt_text = agent.prompt if agent.prompt else ""
    # Replace standard placeholders (case insensitive)
    prompt_text = replace_case_insensitive(prompt_text, "{input}", input_content)
    prompt_text = replace_case_insensitive(prompt_text, "{question}", user_question)
    # Replace data source placeholders (case insensitive)
    for ds in data_sources:
        placeholder = "{" + ds.label + "}"
        prompt_text = replace_case_insensitive(prompt_text, placeholder, ds.content)
    # Check agent type
    if is_sql_agent(agent.model):
        # SQL agent: the rendered prompt IS the SQL query; fall back to the
        # incoming message content when the prompt is empty.
        sql_query = prompt_text.strip()
        if not sql_query:
            sql_query = input_content
        json_result, query_result = execute_sql_query(sql_query)
        # Return JSON result, no entities, and the SQL that was executed
        return json_result, None, sql_query
    elif is_ner_model(agent.model):
        # NER model: the rendered prompt IS the text to analyze; fall back to
        # the incoming message content when the prompt is empty.
        text_to_analyze = prompt_text
        if not text_to_analyze.strip():
            text_to_analyze = input_content
        json_result, entities = process_ner(text_to_analyze, agent.model)
        # Return JSON result, entities, and the text that was analyzed
        return json_result, entities, text_to_analyze
    else:
        # Regular LLM processing via Ollama; coerce non-str results to str
        llm = get_llm(agent.model)
        result = llm.invoke(prompt_text)
        return (result if isinstance(result, str) else str(result)), None, None
# Main execution pipeline
async def execute_pipeline(request: ExecutionRequest) -> AsyncGenerator[str, None]:
    """Run the pub/sub pipeline, yielding SSE frames describing each step.

    Flow: subscribe agents and data sources to their topics, publish the
    user question (or a default) on START, then repeatedly deliver each
    retained topic message to its subscribers. Topics published by agents
    during a sweep are picked up on the next while-iteration; each topic is
    processed at most once, and the loop is capped at max_iterations.
    """
    try:
        bus = MessageBus()
        yield create_event("bus_init")
        # Reset and configure subscriptions
        bus.reset()
        # Subscribe all agents to their topics
        for agent in request.agents:
            if agent.subscribe_topic:
                bus.subscribe(agent.subscribe_topic, agent)
                yield create_event("agent_subscribed", agent=agent.title, topic=agent.subscribe_topic)
        # Subscribe data sources to their topics
        for datasource in request.data_sources:
            if datasource.subscribe_topic:
                bus.subscribe_datasource(datasource.subscribe_topic, datasource)
                yield create_event("datasource_subscribed", datasource=datasource.label, topic=datasource.subscribe_topic)
        # Publish START message (the user question, or a default placeholder)
        start_message = request.user_question if request.user_question else "System initialized"
        bus.publish("START", start_message)
        yield create_event("message_published", topic="START", content=start_message)
        # Process messages in the bus
        processed_topics = set()
        max_iterations = 20  # Prevent infinite loops
        iteration = 0
        while iteration < max_iterations:
            iteration += 1
            # Find topics that have messages but haven't been processed
            # (snapshot list: agents may publish new topics during the sweep)
            topics_to_process = [topic for topic in bus.messages.keys() if topic not in processed_topics]
            if not topics_to_process:
                break
            for topic in topics_to_process:
                subscribers = bus.get_subscribers(topic)
                datasource_subscribers = bus.get_datasource_subscribers(topic)
                if not subscribers and not datasource_subscribers:
                    yield create_event("no_subscribers", topic=topic)
                    processed_topics.add(topic)
                    continue
                message_content = bus.get_message(topic)
                # Update data sources that subscribe to this topic
                # (overwrites their content before any agent reads it via placeholders)
                for datasource in datasource_subscribers:
                    datasource.content = message_content
                    yield create_event("datasource_updated",
                        datasource=datasource.label,
                        topic=topic,
                        content=message_content)
                # Process agents that subscribe to this topic
                for agent in subscribers:
                    yield create_event("agent_triggered", agent=agent.title, topic=topic)
                    yield create_event("agent_processing", agent=agent.title)
                    yield create_event("agent_input", content=message_content)
                    # Execute agent
                    try:
                        result, entities, analyzed_text = await execute_agent(agent, message_content, request.data_sources, request.user_question)
                        yield create_event("agent_output", content=result)
                        # Special handling for SQL agents
                        if is_sql_agent(agent.model) and analyzed_text:
                            # analyzed_text contains the SQL query that was executed.
                            # NOTE(review): success JSON is indented, so the
                            # startswith('{"error"') guard likely never matches an
                            # indent=2 error payload; json.loads handles both anyway.
                            result_dict = json.loads(result) if not result.startswith("{\"error\"") else {"error": "Query failed"}
                            if "row_count" in result_dict:
                                yield create_event("sql_result",
                                    sql=analyzed_text,
                                    rows=result_dict["row_count"])
                        # If agent wants to show result, send it to frontend
                        if agent.show_result:
                            yield create_event("show_result", agent=agent.title, content=result)
                        # If this is an NER agent with entities, also send formatted NER result
                        if entities and is_ner_model(agent.model) and analyzed_text:
                            formatted_text = format_ner_result(analyzed_text, entities)
                            yield create_event("ner_result", agent=agent.title, formatted_text=formatted_text)
                        # Publish result to agent's publish topic (if specified)
                        if agent.publish_topic:
                            bus.publish(agent.publish_topic, result)
                            yield create_event("message_published", topic=agent.publish_topic, content=result)
                        yield create_event("agent_completed", agent=agent.title)
                    except Exception as e:
                        # Report the failure but keep processing remaining agents/topics
                        yield create_event("error", message=f"Agent {agent.title} failed: {str(e)}")
                processed_topics.add(topic)
        yield create_event("execution_complete")
    except Exception as e:
        yield create_event("error", message=str(e))
@app.on_event("startup")
def startup_event():
    """Initialize the database connection and register SQL extensions.

    Sets the module-level ``con`` and flips ``db_ready`` so SQL agents and
    the /status endpoint know the database is usable.
    """
    # NOTE(review): @app.on_event is deprecated in recent FastAPI releases in
    # favor of lifespan handlers — still works, but worth migrating.
    global con, db_ready
    print("Initializing database...")
    con = init_db()
    # Register SQL extensions
    register_extensions(con)
    db_ready = True
    print("Database ready.")
@app.get("/")
async def root():
    """Serve the main web interface, or a hint when index.html is missing."""
    index_file = static_dir / "index.html"
    if not index_file.exists():
        return {"message": "Place index.html in the static/ directory"}
    return FileResponse(index_file)
@app.post("/execute")
async def execute(request: ExecutionRequest):
    """Execute the pub/sub agent system with streaming logs.

    Streams execute_pipeline's Server-Sent Events frames to the client as
    they are produced.
    """
    return StreamingResponse(
        execute_pipeline(request),
        media_type="text/event-stream"
    )
@app.get("/health")
async def health():
    """Liveness probe: always reports the service as up."""
    payload = {"status": "ok"}
    return payload
@app.get("/status")
def status():
    """Readiness probe: reports whether database initialization completed."""
    return {"ready": db_ready}
if __name__ == "__main__":
    import uvicorn
    # Port comes from the PORT environment variable, defaulting to 7860
    port = int(os.environ.get("PORT", 7860))
    uvicorn.run(app, host="0.0.0.0", port=port)