File size: 16,540 Bytes
1d41764 30f499c 1d41764 12a6aaa 16d3318 ff6e4be 1d41764 16d3318 1d41764 ce59284 81f916b ce59284 1d41764 ce59284 1d41764 ce59284 e57db30 1d41764 16d3318 1d41764 81f916b 1d41764 81f916b 1d41764 be34b9e 1d41764 be34b9e 1d41764 81f916b 1d41764 be34b9e 1d41764 be34b9e 1d41764 be34b9e 81f916b 1d41764 ff6e4be 30f499c 16d3318 30f499c 1d41764 ff6e4be be34b9e ff6e4be 30f499c ff6e4be 81f916b ff6e4be 1d41764 ff6e4be 16d3318 ff6e4be 4c96155 ff6e4be 4c96155 ff6e4be 4c96155 ff6e4be 30f499c 99c42aa 1d41764 16d3318 10fc8ae 16d3318 99c42aa 30f499c 99c42aa ff6e4be 99c42aa 30f499c ff6e4be 99c42aa 30f499c 99c42aa 1d41764 81f916b 1d41764 e57db30 1d41764 81f916b 1d41764 81f916b 1d41764 81f916b 1d41764 99c42aa 1d41764 16d3318 ce59284 30f499c 99c42aa 30f499c ce59284 1d41764 16d3318 1d41764 16d3318 1d41764 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 
329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 | from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse, FileResponse
from fastapi.staticfiles import StaticFiles
from pydantic import BaseModel
from langchain_community.llms import Ollama
from langchain.prompts import PromptTemplate
from typing import List, Dict, Optional, AsyncGenerator
import json
import asyncio
from pathlib import Path
import os
import re
from database import init_db
from db_extensions import register_extensions
# Import transformers for NER
# Optional dependency: if the import fails the app still starts; the flag is
# checked by get_ner_pipeline(), which raises instead of loading a model.
try:
    from transformers import pipeline
    TRANSFORMERS_AVAILABLE = True
except ImportError:
    TRANSFORMERS_AVAILABLE = False
    print("Warning: transformers not available, NER models will not work")
app = FastAPI(title="Pub/Sub Multi-Agent System")
# Database state. Both are assigned by the startup handler; db_ready stays
# False until the connection is open and SQL extensions are registered.
db_ready = False
con = None
# Enable CORS
# NOTE(review): wildcard origins combined with allow_credentials=True is a very
# permissive combination — confirm this is intended (e.g. a public demo) before
# deploying anywhere that relies on cookies/credentials.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
# Mount static files
# The directory is created before mounting — presumably because StaticFiles
# expects it to exist when constructed (verify against the Starlette version in use).
static_dir = Path(__file__).parent / "static"
static_dir.mkdir(exist_ok=True)
app.mount("/static", StaticFiles(directory=static_dir), name="static")
# Models
class DataSource(BaseModel):
    # Name used for the "{label}" placeholder in agent prompts (matched case-insensitively).
    label: str
    # Text substituted for that placeholder; overwritten during a run whenever a
    # message arrives on subscribe_topic.
    content: str
    # Optional topic whose messages replace `content` during execution.
    subscribe_topic: Optional[str] = None
class Agent(BaseModel):
    # Display name reported in streamed events.
    title: str
    # Prompt template with {input}, {question} and {<data source label>}
    # placeholders. For SQL and NER agents the rendered text is itself the
    # query / the text to analyze.
    prompt: str
    # Backend selector: "SQL" (any case), one of the supported NER model ids,
    # or otherwise an Ollama model name.
    model: str
    # Topic that triggers this agent (matched case-insensitively).
    subscribe_topic: str
    # Topic the agent's result is published to; None/empty means it is not forwarded.
    publish_topic: Optional[str] = None
    # When True the result is also streamed to the frontend as a "show_result" event.
    show_result: bool = False
class ExecutionRequest(BaseModel):
    # Request body of POST /execute.
    data_sources: List[DataSource]
    # Published on the START topic and substituted for {question}; when empty,
    # the START message is "System initialized".
    user_question: str = ""
    agents: List[Agent]
class Query(BaseModel):
    # Raw SQL text.
    # NOTE(review): not referenced by any endpoint in this file — possibly used
    # by another module or left over; confirm before removing.
    sql: str
# Pub/Sub Bus
class MessageBus:
    """In-memory topic registry used for a single pipeline run.

    Topic names are normalized (lower-cased, then stripped of surrounding
    whitespace) on every call, so "Start", " START " and "start" name the same
    topic. Each topic keeps only its most recently published message.
    """

    def __init__(self):
        self.reset()

    def reset(self):
        """Drop every subscription and every stored message."""
        # Per-topic agent subscribers, in subscription order.
        self.subscribers: Dict[str, List[Agent]] = {}
        # Per-topic data-source subscribers, in subscription order.
        self.datasource_subscribers: Dict[str, List[DataSource]] = {}
        # Latest message per topic; publishing again replaces it.
        self.messages: Dict[str, str] = {}

    def _normalize_topic(self, topic: str) -> str:
        """Canonical key for *topic*: lower-cased, surrounding whitespace removed."""
        return topic.lower().strip()

    def subscribe(self, topic: str, agent: Agent):
        """Register *agent* as a subscriber of *topic*."""
        self.subscribers.setdefault(self._normalize_topic(topic), []).append(agent)

    def subscribe_datasource(self, topic: str, datasource: DataSource):
        """Register *datasource* as a subscriber of *topic*."""
        self.datasource_subscribers.setdefault(self._normalize_topic(topic), []).append(datasource)

    def publish(self, topic: str, content: str):
        """Store *content* as the current message of *topic*, replacing any earlier one."""
        key = self._normalize_topic(topic)
        self.messages[key] = content

    def get_message(self, topic: str) -> Optional[str]:
        """Current message of *topic*, or None when nothing was published there."""
        return self.messages.get(self._normalize_topic(topic))

    def get_subscribers(self, topic: str) -> List[Agent]:
        """Agents subscribed to *topic*; a fresh empty list when there are none."""
        return self.subscribers.get(self._normalize_topic(topic), [])

    def get_datasource_subscribers(self, topic: str) -> List[DataSource]:
        """Data sources subscribed to *topic*; a fresh empty list when there are none."""
        return self.datasource_subscribers.get(self._normalize_topic(topic), [])
# Stream event helper
def create_event(event_type: str, **kwargs):
    """Serialize one Server-Sent Events ``data:`` frame.

    The payload is a JSON object whose first key is ``type`` (set to
    *event_type*), followed by every keyword argument in call order. A keyword
    literally named ``type`` overrides *event_type*, because it is merged last.
    """
    payload = {"type": event_type}
    payload.update(kwargs)
    body = json.dumps(payload)
    return "data: " + body + "\n\n"
# Get LLM instance
def get_llm(model_name: str):
    """Build a LangChain ``Ollama`` client for *model_name*.

    A new client is constructed on every call; the sampling temperature is
    fixed at 0.1.
    """
    llm_kwargs = {"model": model_name, "temperature": 0.1}
    return Ollama(**llm_kwargs)
# NER pipeline cache
# Maps model name -> loaded pipeline. Filled lazily; entries are never evicted.
_ner_pipelines = {}


def get_ner_pipeline(model_name: str):
    """Return the transformers NER pipeline for *model_name*, loading it on first use.

    The pipeline is created with ``aggregation_strategy="simple"`` so that it
    yields grouped entities (``entity_group``/``start``/``end``), and is reused
    for every later call with the same name.

    Raises:
        RuntimeError: if the optional ``transformers`` import failed.
    """
    if not TRANSFORMERS_AVAILABLE:
        raise RuntimeError("transformers package not available")
    try:
        return _ner_pipelines[model_name]
    except KeyError:
        pass  # not loaded yet; fall through and build it
    print(f"Loading NER model: {model_name}")
    loaded = pipeline("ner", model=model_name, aggregation_strategy="simple")
    _ner_pipelines[model_name] = loaded
    return loaded
# Check if model is NER model
def is_ner_model(model_name: str) -> bool:
    """Return True when *model_name* is one of the supported NER model ids.

    Matching is exact and case-sensitive.
    """
    supported = (
        "samrawal/bert-base-uncased_clinical-ner",
        "OpenMed/OpenMed-NER-AnatomyDetect-BioPatient-108M",
    )
    return model_name in supported
# Check if model is SQL agent
def is_sql_agent(model_name: str) -> bool:
    """Return True when the agent's model field selects the built-in SQL executor.

    The comparison ignores case ("sql", "Sql", "SQL" all match) but does not
    strip surrounding whitespace.
    """
    return "SQL" == model_name.upper()
# Format NER output for display
def format_ner_result(text: str, entities: List[Dict]) -> str:
    """Return *text* with every entity span wrapped as ``[span:TYPE]``.

    Args:
        text: The exact string the entities were extracted from; each entity's
            ``start``/``end`` offsets index into it.
        entities: Entity dicts with ``start``, ``end`` and a type label. The
            label is read from ``entity_type`` (the key ``process_ner`` emits)
            and falls back to ``entity_group`` (the raw transformers key).

    Returns:
        The annotated string, or *text* unchanged when *entities* is empty.
        An entity that overlaps a span already labeled is skipped, since
        splicing it in would cut through the neighbouring label.

    Raises:
        KeyError: if an entity has neither ``entity_type`` nor ``entity_group``.
    """
    if not entities:
        return text
    result = text
    # Splice from the end of the string backwards so inserting a label never
    # shifts the offsets of entities that are still to be processed.
    next_start = None  # start offset of the most recently labeled span
    for entity in sorted(entities, key=lambda e: e["start"], reverse=True):
        start = entity["start"]
        end = entity["end"]
        if next_start is not None and end > next_start:
            # Overlaps the span labeled in the previous step.
            continue
        if "entity_type" in entity:
            entity_type = entity["entity_type"]
        else:
            entity_type = entity["entity_group"]
        labeled = f"[{text[start:end]}:{entity_type}]"
        result = result[:start] + labeled + result[end:]
        next_start = start
    return result
# Execute SQL query
def execute_sql_query(sql: str) -> tuple[str, Optional[Dict]]:
    """Run *sql* on the shared module-level database connection.

    Args:
        sql: Statement text, executed verbatim.
            NOTE(review): this comes from an agent prompt or an upstream message
            and runs with no parameterization or read-only guard.

    Returns:
        ``(json_text, result)``. On success ``result`` is a dict with
        ``columns`` (column names), ``rows`` (the ``fetchall()`` rows) and
        ``row_count``, and ``json_text`` is that dict pretty-printed as JSON.
        On failure (database not ready, or any error while executing)
        ``result`` is None and ``json_text`` is ``{"error": "..."}``.
    """
    if not db_ready:
        # Must still be a 2-tuple: callers unpack the return value directly.
        return json.dumps({"error": "Database not ready"}), None
    try:
        # Read rows and description from the object execute() returned; for
        # DB-API connections such as sqlite3 this is a fresh cursor, and the
        # connection object itself has no `.description`.
        cursor = con.execute(sql)
        rows = cursor.fetchall()
        # description is None for statements that produce no result set.
        description = cursor.description or []
        columns = [desc[0] for desc in description]
        formatted_result = {
            "columns": columns,
            "rows": rows,
            "row_count": len(rows),
        }
        # default=str is deliberate: this JSON is for display, so dates,
        # decimals, etc. are rendered as text instead of failing the query.
        json_output = json.dumps(formatted_result, indent=2, default=str)
        return json_output, formatted_result
    except Exception as e:
        error_msg = f"SQL execution failed: {str(e)}"
        return json.dumps({"error": error_msg}), None
# Process NER with transformers pipeline
def process_ner(text: str, model_name: str) -> tuple[str, List[Dict]]:
    """Run the NER pipeline for *model_name* over *text*.

    Returns:
        ``(json_text, entities)``. Each entity is a plain-Python dict with
        ``text``, ``entity_type``, ``start``, ``end`` and ``score``; values are
        coerced to str/int/float so ``json.dumps`` accepts them (the pipeline
        can return numpy scalars). On any failure ``json_text`` is
        ``{"error": "NER processing failed: ..."}`` and ``entities`` is empty.
    """
    try:
        raw_entities = get_ner_pipeline(model_name)(text)
        entities = [
            {
                "text": str(item['word']),
                "entity_type": str(item['entity_group']),
                "start": int(item['start']),
                "end": int(item['end']),
                "score": float(item.get('score', 0.0)),
            }
            for item in raw_entities
        ]
        return json.dumps(entities, indent=2), entities
    except Exception as exc:
        return json.dumps({"error": f"NER processing failed: {str(exc)}"}), []
# Execute agent
async def execute_agent(agent: Agent, input_content: str, data_sources: List[DataSource], user_question: str) -> tuple[str, Optional[List[Dict]], Optional[str]]:
    """Render an agent's prompt and run it on the backend selected by ``agent.model``.

    Placeholders are substituted case-insensitively in this order: ``{input}``
    -> *input_content*, ``{question}`` -> *user_question*, then
    ``{<label>}`` -> ``content`` for each data source. Because substitution is
    sequential, text inserted by an earlier step can be expanded by a later one.

    Dispatch on ``agent.model``:
      * SQL agent: the stripped rendered prompt (or *input_content* if that is
        empty) is executed as SQL.
      * NER model: the rendered prompt (or *input_content* if it is blank) is
        run through the NER pipeline.
      * Otherwise: the rendered prompt is sent to an Ollama LLM.

    Returns:
        ``(result, entities, analyzed_text)`` — *result* is the text to
        publish; *entities* is the entity list for NER agents, else None;
        *analyzed_text* is the SQL executed (SQL agents), the text analyzed
        (NER agents), or None (LLM agents).
    """

    def replace_case_insensitive(text: str, placeholder: str, value: str) -> str:
        """Replace every case-insensitive occurrence of *placeholder* with *value*, verbatim."""
        pattern = re.compile(re.escape(placeholder), re.IGNORECASE)
        # A function replacement inserts *value* literally. With a plain string,
        # re.sub would interpret backslashes in user/LLM/data text ("\d" raises
        # re.error, "\n" becomes a newline, "\1" a group reference).
        return pattern.sub(lambda _match: value, text)

    # Start with the agent's prompt template
    prompt_text = agent.prompt if agent.prompt else ""
    prompt_text = replace_case_insensitive(prompt_text, "{input}", input_content)
    prompt_text = replace_case_insensitive(prompt_text, "{question}", user_question)
    for ds in data_sources:
        prompt_text = replace_case_insensitive(prompt_text, "{" + ds.label + "}", ds.content)

    # NOTE(review): the SQL and NER branches still run synchronously on the
    # event loop. SQL is left inline because it uses the single shared `con`
    # connection, whose thread-safety is not established in this file.
    if is_sql_agent(agent.model):
        # SQL agent: the rendered prompt IS the SQL query
        sql_query = prompt_text.strip()
        if not sql_query:
            sql_query = input_content
        json_result, _query_result = execute_sql_query(sql_query)
        return json_result, None, sql_query

    if is_ner_model(agent.model):
        # NER agent: the rendered prompt IS the text to analyze
        text_to_analyze = prompt_text
        if not text_to_analyze.strip():
            text_to_analyze = input_content
        json_result, entities = process_ner(text_to_analyze, agent.model)
        return json_result, entities, text_to_analyze

    # Regular LLM processing. llm.invoke is a blocking network call; running it
    # in a worker thread keeps the event loop (and every other request,
    # including the SSE stream itself) responsive while the model answers.
    llm = get_llm(agent.model)
    result = await asyncio.to_thread(llm.invoke, prompt_text)
    return (result if isinstance(result, str) else str(result)), None, None
# Main execution pipeline
async def execute_pipeline(request: ExecutionRequest) -> AsyncGenerator[str, None]:
    """Run one pub/sub execution and stream its progress as SSE frames.

    Flow:
      1. Build a fresh MessageBus and register every agent and data source
         that has a ``subscribe_topic``.
      2. Publish the user question (or "System initialized") on ``START``.
      3. In rounds, deliver every topic that holds a message and has not been
         delivered yet. Data-source subscribers get their ``content``
         overwritten with the message. Agent subscribers are executed, and
         their results are published on their ``publish_topic``. Stops when no
         undelivered topic remains or after ``max_iterations`` rounds. Hitting
         the cap ends the loop silently; ``execution_complete`` is still emitted.

    Each topic is delivered at most once per run. A later publish to an
    already-delivered topic overwrites the stored message but triggers nothing.
    Every step is reported via ``create_event``. Exceptions are turned into
    ``error`` events rather than propagated.
    """
    try:
        bus = MessageBus()
        yield create_event("bus_init")
        # Reset and configure subscriptions
        # (redundant right after construction — the constructor already starts empty)
        bus.reset()
        # Subscribe all agents to their topics
        for agent in request.agents:
            if agent.subscribe_topic:
                bus.subscribe(agent.subscribe_topic, agent)
                yield create_event("agent_subscribed", agent=agent.title, topic=agent.subscribe_topic)
        # Subscribe data sources to their topics
        for datasource in request.data_sources:
            if datasource.subscribe_topic:
                bus.subscribe_datasource(datasource.subscribe_topic, datasource)
                yield create_event("datasource_subscribed", datasource=datasource.label, topic=datasource.subscribe_topic)
        # Publish START message
        start_message = request.user_question if request.user_question else "System initialized"
        bus.publish("START", start_message)
        yield create_event("message_published", topic="START", content=start_message)
        # Process messages in the bus
        processed_topics = set()
        max_iterations = 20 # Prevent infinite loops
        iteration = 0
        while iteration < max_iterations:
            iteration += 1
            # Find topics that have messages but haven't been processed
            # (snapshot: topics published during this round are picked up next round)
            topics_to_process = [topic for topic in bus.messages.keys() if topic not in processed_topics]
            if not topics_to_process:
                break
            for topic in topics_to_process:
                subscribers = bus.get_subscribers(topic)
                datasource_subscribers = bus.get_datasource_subscribers(topic)
                if not subscribers and not datasource_subscribers:
                    yield create_event("no_subscribers", topic=topic)
                    processed_topics.add(topic)
                    continue
                message_content = bus.get_message(topic)
                # Update data sources that subscribe to this topic.
                # These are the same objects as request.data_sources, so the new
                # content is what later agents see for the {label} placeholder.
                for datasource in datasource_subscribers:
                    datasource.content = message_content
                    yield create_event("datasource_updated",
                        datasource=datasource.label,
                        topic=topic,
                        content=message_content)
                # Process agents that subscribe to this topic
                for agent in subscribers:
                    yield create_event("agent_triggered", agent=agent.title, topic=topic)
                    yield create_event("agent_processing", agent=agent.title)
                    yield create_event("agent_input", content=message_content)
                    # Execute agent; a failure is reported and the next agent still runs
                    try:
                        result, entities, analyzed_text = await execute_agent(agent, message_content, request.data_sources, request.user_question)
                        yield create_event("agent_output", content=result)
                        # Special handling for SQL agents
                        if is_sql_agent(agent.model) and analyzed_text:
                            # analyzed_text contains the SQL query that was executed.
                            # Error payloads are not parsed; only successful results carry row_count.
                            result_dict = json.loads(result) if not result.startswith("{\"error\"") else {"error": "Query failed"}
                            if "row_count" in result_dict:
                                yield create_event("sql_result",
                                    sql=analyzed_text,
                                    rows=result_dict["row_count"])
                        # If agent wants to show result, send it to frontend
                        if agent.show_result:
                            yield create_event("show_result", agent=agent.title, content=result)
                        # If this is an NER agent with entities, also send formatted NER result
                        if entities and is_ner_model(agent.model) and analyzed_text:
                            formatted_text = format_ner_result(analyzed_text, entities)
                            yield create_event("ner_result", agent=agent.title, formatted_text=formatted_text)
                        # Publish result to agent's publish topic (if specified)
                        if agent.publish_topic:
                            bus.publish(agent.publish_topic, result)
                            yield create_event("message_published", topic=agent.publish_topic, content=result)
                        yield create_event("agent_completed", agent=agent.title)
                    except Exception as e:
                        yield create_event("error", message=f"Agent {agent.title} failed: {str(e)}")
                # Marked only after all subscribers ran; never delivered again this run.
                processed_topics.add(topic)
        yield create_event("execution_complete")
    except Exception as e:
        yield create_event("error", message=str(e))
@app.on_event("startup")
def startup_event():
    """Open the database connection when the application starts.

    Assigns the module-level ``con`` first and sets ``db_ready`` only after
    ``register_extensions`` returns. Until both steps succeed, ``/status``
    reports not-ready and ``execute_sql_query`` refuses to run.

    NOTE(review): newer FastAPI releases deprecate ``on_event`` in favour of a
    ``lifespan`` handler — consider migrating when upgrading.
    """
    global con, db_ready
    print("Initializing database...")
    con = init_db()
    # Register SQL extensions
    register_extensions(con)
    db_ready = True
    print("Database ready.")
@app.get("/")
async def root():
    """Serve the main web interface"""
    # Falls back to a JSON hint when static/index.html has not been deployed.
    index_file = static_dir / "index.html"
    if not index_file.exists():
        return {"message": "Place index.html in the static/ directory"}
    return FileResponse(index_file)
@app.post("/execute")
async def execute(request: ExecutionRequest):
    """Execute the pub/sub agent system with streaming logs"""
    # The pipeline is an async generator of SSE frames; StreamingResponse sends
    # each frame to the client as it is produced.
    event_stream = execute_pipeline(request)
    return StreamingResponse(event_stream, media_type="text/event-stream")
@app.get("/health")
async def health():
    """Health check endpoint"""
    # Liveness only: reports "ok" whenever the process is serving requests,
    # regardless of database readiness (the /status route reports that).
    response = {"status": "ok"}
    return response
@app.get("/status")
def status():
    # Readiness probe: mirrors the module-level db_ready flag, which the
    # startup handler sets to True once the database is initialized.
    ready = db_ready
    return {"ready": ready}
if __name__ == "__main__":
    import uvicorn
    # Listen on all interfaces; the PORT environment variable overrides 7860.
    uvicorn.run(app, host="0.0.0.0", port=int(os.getenv("PORT", 7860)))
|