|
|
"""Conversation Service for managing active AI-to-AI conversations. |
|
|
|
|
|
This service acts as the bridge between the WebSocket interface and the |
|
|
ConversationManager. It handles the lifecycle of conversations, manages |
|
|
active instances, and coordinates message streaming to connected clients. |
|
|
|
|
|
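Classes:
    ConversationService: Main service for conversation management
    ConversationInfo: Data class for AI-to-AI conversation metadata
    HumanChatInfo: Data class for human-to-AI chat session metadata
    ConversationStatus: Enum of conversation lifecycle states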
|
|
|
|
|
Example:
    service = ConversationService(websocket_manager)
    started = await service.start_conversation(
        conversation_id="conv_001",
        surveyor_persona_id="surveyor_001",
        patient_persona_id="patient_001",
    )  # returns True when the conversation started
|
|
""" |
|
|
|
|
|
import asyncio |
|
|
import logging |
|
|
from datetime import datetime |
|
|
from typing import Dict, Optional, Any, List |
|
|
from dataclasses import dataclass, field |
|
|
from enum import Enum |
|
|
import json |
|
|
import sys |
|
|
from pathlib import Path |
|
|
|
|
|
import pysbd |
|
|
|
|
|
|
|
|
# Make the backend package directory and the repository root importable so
# the `config.*` and `backend.*` imports that follow can be resolved.
# Each directory is prepended only if it is not already on sys.path; because
# both use insert(0, ...), PROJECT_ROOT ends up ahead of BACKEND_DIR when
# both are added.
BACKEND_DIR, PROJECT_ROOT = (Path(__file__).resolve().parents[level] for level in (2, 3))

for path in (BACKEND_DIR, PROJECT_ROOT):
    if str(path) in sys.path:
        continue
    sys.path.insert(0, str(path))
|
|
|
|
|
from config.settings import AppSettings, get_settings |
|
|
from backend.core.conversation_manager import ConversationManager |
|
|
from backend.core.llm_client import create_llm_client |
|
|
from backend.core.persona_system import get_persona_system |
|
|
from .conversation_ws import ConnectionManager |
|
|
from .storage_service import get_run_store |
|
|
from .storage_service import get_persona_store |
|
|
from backend.storage import RunRecord |
|
|
from backend.core.surveyor_knobs import compile_surveyor_attributes_overlay, compile_question_bank_overlay |
|
|
from backend.core.patient_knobs import compile_patient_attributes_overlay |
|
|
from backend.core.analysis_knobs import compile_analysis_rules_block |
|
|
from backend.core.universal_prompts import DEFAULT_PATIENT_SYSTEM_PROMPT, DEFAULT_SURVEYOR_SYSTEM_PROMPT |
|
|
|
|
|
|
|
|
# Module-level logger named after this module.
logger = logging.getLogger(__name__)

# One shared English sentence splitter, built once at import time.
# clean=False: segment the text as given, without pysbd's text-cleaning pass.
_SENTENCE_SEGMENTER = pysbd.Segmenter(language="en", clean=False)

# Upper bounds passed as `max_tokens` to the LLM client for each reply.
SURVEYOR_MAX_TOKENS = 140
PATIENT_MAX_TOKENS = 220
|
|
|
|
|
|
|
|
def _split_sentences(text: str) -> List[str]: |
|
|
normalized = " ".join((text or "").split()) |
|
|
if not normalized: |
|
|
return [] |
|
|
try: |
|
|
sentences = [s.strip() for s in _SENTENCE_SEGMENTER.segment(normalized) if s.strip()] |
|
|
except Exception: |
|
|
sentences = [] |
|
|
return sentences or [normalized] |
|
|
|
|
|
|
|
|
def _normalize_confidence(value: Any) -> Optional[float]: |
|
|
"""Normalize confidence values to a float in [0, 1].""" |
|
|
try: |
|
|
confidence = float(value) |
|
|
except (TypeError, ValueError): |
|
|
return None |
|
|
|
|
|
if confidence < 0: |
|
|
confidence = 0.0 |
|
|
|
|
|
if confidence > 1.0: |
|
|
|
|
|
if confidence <= 100.0: |
|
|
confidence = confidence / 100.0 |
|
|
else: |
|
|
confidence = 1.0 |
|
|
|
|
|
return max(0.0, min(1.0, confidence)) |
|
|
|
|
|
|
|
|
async def run_resource_agent_analysis(
    *,
    transcript: List[Dict[str, Any]],
    llm_backend: str,
    host: str,
    model: str,
    settings: AppSettings,
    analysis_attributes: Optional[List[str]] = None,
    analysis_system_prompt: Optional[str] = None,
) -> Dict[str, Any]:
    """Run the resource agent analysis on an in-memory transcript and return parsed JSON.

    Shared by the live conversation flow and ad-hoc analysis endpoints.

    Every transcript message with an integer ``index`` is split into
    sentences, and each sentence gets an evidence id ``m<index>s<sentence>``.
    That catalog plus a JSON schema is sent to the LLM. The reply is parsed
    as JSON (a surrounding markdown code fence is tolerated), annotated with
    the catalog and prompt version, and every ``confidence`` field is
    normalized to [0, 1] via ``_normalize_confidence``.

    Args:
        transcript: Message dicts; ``index`` (int) and ``content`` are read.
        llm_backend: Backend name passed to ``create_llm_client``.
        host: LLM host passed to ``create_llm_client``.
        model: Model name passed to ``create_llm_client``.
        settings: Application settings; LLM timeout/retry/auth values are read.
        analysis_attributes: Extra rules compiled into the system prompt.
        analysis_system_prompt: Base system prompt; a built-in default is
            used when it is empty or blank.

    Returns:
        The parsed analysis object with ``evidence_catalog`` and
        ``analysis_prompt_version`` set by this function.

    Raises:
        ValueError: If the reply is not valid JSON (``json.JSONDecodeError``)
            or is valid JSON but not an object.
        Any exception raised by the LLM client's ``generate`` call.
    """
    schema_version = "7"
    analysis_prompt_version = "v2"

    # Give every sentence of every indexed message a stable id the model can
    # cite as evidence.
    evidence_catalog: Dict[str, Dict[str, Any]] = {}
    for message in transcript:
        message_index = message.get("index")
        content = message.get("content", "") or ""
        if not isinstance(message_index, int):
            continue
        for sentence_index, sentence in enumerate(_split_sentences(content)):
            evidence_catalog[f"m{message_index}s{sentence_index}"] = {
                "message_index": message_index,
                "sentence_index": sentence_index,
                "text": sentence,
            }

    base = (analysis_system_prompt or "").strip()
    if not base:
        base = (
            "You are a clinical research 'resource agent'. You are given a transcript of a simulated "
            "health survey conversation between a surveyor and a patient. Your task is to extract "
            "post-hoc insights as strict JSON for a UI."
        )
    system_prompt = (base + "\n\n" + compile_analysis_rules_block(analysis_attributes)).strip()

    evidence_catalog_json = json.dumps(evidence_catalog, ensure_ascii=False)
    user_prompt = (
        "Evidence catalog (JSON object mapping evidence_id -> sentence):\n"
        f"{evidence_catalog_json}\n\n"
        "Return JSON matching this schema:\n"
        "{\n"
        f" \"schema_version\": \"{schema_version}\",\n"
        f" \"analysis_prompt_version\": \"{analysis_prompt_version}\",\n"
        " \"health_situations\": [\n"
        " {\n"
        " \"code\": string, // 1-3 word label\n"
        " \"summary\": string,\n"
        " \"evidence\": [ {\"evidence_id\": string} ],\n"
        " \"confidence\": number // 0..1\n"
        " }\n"
        " ],\n"
        " \"care_experience\": {\n"
        " \"positive\": {\n"
        " \"summary\": string,\n"
        " \"reasons\": [string],\n"
        " \"evidence\": [ {\"evidence_id\": string} ],\n"
        " \"confidence\": number // 0..1\n"
        " },\n"
        " \"mixed\": {\n"
        " \"summary\": string,\n"
        " \"reasons\": [string],\n"
        " \"evidence\": [ {\"evidence_id\": string} ],\n"
        " \"confidence\": number // 0..1\n"
        " },\n"
        " \"negative\": {\n"
        " \"summary\": string,\n"
        " \"reasons\": [string],\n"
        " \"evidence\": [ {\"evidence_id\": string} ],\n"
        " \"confidence\": number // 0..1\n"
        " },\n"
        " \"neutral\": {\n"
        " \"summary\": string,\n"
        " \"reasons\": [string],\n"
        " \"evidence\": [ {\"evidence_id\": string} ],\n"
        " \"confidence\": number // 0..1\n"
        " }\n"
        # Comma required: care_experience is followed by top_down_codes.
        " },\n"
        " \"top_down_codes\": {\n"
        " \"symptoms_concerns\": [ {\"code\": string, \"summary\": string, \"evidence\": [ {\"evidence_id\": string} ], \"confidence\": number // 0..1 } ],\n"
        " \"daily_management\": [ {\"code\": string, \"summary\": string, \"evidence\": [ {\"evidence_id\": string} ], \"confidence\": number // 0..1 } ],\n"
        " \"barriers_constraints\": [ {\"code\": string, \"summary\": string, \"evidence\": [ {\"evidence_id\": string} ], \"confidence\": number // 0..1 } ],\n"
        " \"support_resources\": [ {\"code\": string, \"summary\": string, \"evidence\": [ {\"evidence_id\": string} ], \"confidence\": number // 0..1 } ]\n"
        " }\n"
        "}\n"
    )

    llm_params: Dict[str, Any] = {
        "timeout": settings.llm.timeout,
        "max_retries": settings.llm.max_retries,
        "retry_delay": settings.llm.retry_delay,
    }
    if settings.llm.api_key:
        llm_params["api_key"] = settings.llm.api_key
    if settings.llm.site_url:
        llm_params["site_url"] = settings.llm.site_url
    if settings.llm.app_name:
        llm_params["app_name"] = settings.llm.app_name

    # Create the client only after all prompt-building work that can raise,
    # so the try/finally below is guaranteed to close it.
    client = create_llm_client(
        llm_backend,
        host=host,
        model=model,
        **llm_params,
    )

    def _normalize_entry(entry: Any) -> None:
        # Overwrite ``confidence`` only when it normalizes to a number; skip
        # anything that is not a dict (the model may return odd shapes).
        if not isinstance(entry, dict):
            return
        normalized = _normalize_confidence(entry.get("confidence"))
        if normalized is not None:
            entry["confidence"] = normalized

    try:
        raw = await client.generate(prompt=user_prompt, system_prompt=system_prompt, temperature=0.2)

        text = (raw or "").strip()
        if text.startswith("```"):
            # Tolerate a markdown code fence (``` or ```json) around the JSON.
            _, _, text = text.partition("\n")
            text = text.strip()
            if text.endswith("```"):
                text = text[:-3].strip()

        parsed = json.loads(text)
        if not isinstance(parsed, dict):
            raise ValueError(
                f"Resource agent returned a JSON {type(parsed).__name__}, expected an object"
            )

        parsed["evidence_catalog"] = evidence_catalog
        parsed["analysis_prompt_version"] = analysis_prompt_version

        situations = parsed.get("health_situations")
        if isinstance(situations, list):
            for item in situations:
                _normalize_entry(item)

        care_experience = parsed.get("care_experience")
        if isinstance(care_experience, dict):
            for key in ("positive", "mixed", "negative", "neutral"):
                _normalize_entry(care_experience.get(key))

        top_down_codes = parsed.get("top_down_codes")
        if isinstance(top_down_codes, dict):
            for key in ("symptoms_concerns", "daily_management", "barriers_constraints", "support_resources"):
                items = top_down_codes.get(key)
                if isinstance(items, list):
                    for item in items:
                        _normalize_entry(item)

        return parsed
    finally:
        try:
            await client.close()
        except Exception:
            # Best-effort cleanup: never let a close failure mask the result
            # or the original exception.
            logger.debug("Failed to close resource-agent LLM client", exc_info=True)
|
|
|
|
|
|
|
|
class ConversationStatus(Enum):
    """Lifecycle states shared by AI-to-AI conversations and human chats.

    A member's value is the lowercase string reported by status queries
    such as ``get_conversation_status``.
    """

    STARTING = "starting"    # registered, setup in progress
    RUNNING = "running"      # exchanging messages
    PAUSED = "paused"        # not assigned by the service code shown here
    STOPPING = "stopping"    # a stop was requested and is in progress
    COMPLETED = "completed"  # ended normally or after a stop
    ERROR = "error"          # stopping failed
|
|
|
|
|
|
|
|
@dataclass
class ConversationInfo:
    """Bookkeeping record for one AI-to-AI conversation run by the service.

    Fields without defaults must be supplied; list-valued fields get a fresh
    empty list per instance.
    """

    # Participants.
    conversation_id: str
    surveyor_persona_id: str
    patient_persona_id: str

    # LLM endpoint serving this conversation.
    host: str
    model: str
    llm_backend: str

    # Lifecycle.
    status: ConversationStatus
    created_at: datetime
    message_count: int = 0  # incremented for each streamed message
    task: Optional[asyncio.Task] = None  # background streaming task, once started
    stop_requested: bool = False  # set when a stop has been requested

    # Prompt configuration resolved when the conversation starts.
    surveyor_system_prompt: str = ""
    patient_system_prompt: str = ""
    analysis_system_prompt: str = ""
    analysis_attributes: List[str] = field(default_factory=list)
    patient_attributes: List[str] = field(default_factory=list)
    surveyor_attributes: List[str] = field(default_factory=list)
    surveyor_question_bank: Optional[str] = None  # newline-separated questions, if any
|
|
|
|
|
|
|
|
@dataclass
class HumanChatInfo:
    """Bookkeeping record for one human-to-AI chat session.

    Fields without defaults must be supplied; list and lock fields get a
    fresh instance per session.
    """

    # Participants.
    conversation_id: str
    surveyor_persona_id: str
    patient_persona_id: str

    # LLM endpoint serving this session.
    host: str
    model: str
    llm_backend: str

    # Lifecycle.
    status: ConversationStatus
    created_at: datetime
    stop_requested: bool = False  # set by stop_conversation

    # Prompt configuration resolved when the session starts.
    surveyor_system_prompt: str = ""
    patient_system_prompt: str = ""
    analysis_system_prompt: str = ""
    analysis_attributes: List[str] = field(default_factory=list)
    patient_attributes: List[str] = field(default_factory=list)
    surveyor_attributes: List[str] = field(default_factory=list)
    surveyor_question_bank: Optional[str] = None  # newline-separated questions, if any

    # Session-specific state.
    ai_role: str = "surveyor"  # side the AI plays: "surveyor" or "patient"
    asked_question_ids: List[str] = field(default_factory=list)  # question-bank ids already used
    lock: asyncio.Lock = field(default_factory=asyncio.Lock)  # taken by human_chat_message and end_human_chat
    client: Any = None  # per-session LLM client
|
|
|
|
|
|
|
|
class ConversationService: |
|
|
"""Service for managing AI-to-AI conversation instances. |
|
|
|
|
|
This service coordinates between the ConversationManager and WebSocket |
|
|
infrastructure to provide real-time conversation streaming to web clients. |
|
|
|
|
|
Attributes: |
|
|
websocket_manager: WebSocket connection manager for broadcasting |
|
|
persona_system: Persona system for loading personas |
|
|
active_conversations: Dict of active conversation instances |
|
|
settings: Shared application settings |
|
|
""" |
|
|
|
|
|
def __init__(self, websocket_manager: ConnectionManager, settings: Optional[AppSettings] = None):
    """Wire the service to its collaborators and create empty registries.

    Args:
        websocket_manager: Connection manager used to push messages and
            status updates to a conversation's clients.
        settings: Application settings to use; when not given (or falsy),
            ``get_settings()`` supplies the shared instance.
    """
    self.websocket_manager = websocket_manager
    self.persona_system = get_persona_system()

    # Per-conversation state, all keyed by conversation_id.
    self.active_conversations: Dict[str, ConversationInfo] = dict()
    self.active_human_chats: Dict[str, HumanChatInfo] = dict()
    self.transcripts: Dict[str, List[Dict[str, Any]]] = dict()

    self.settings = settings if settings else get_settings()
|
|
|
|
|
def _persona_question_bank(self, persona: Dict[str, Any]) -> Optional[str]: |
|
|
items = persona.get("question_bank_items") |
|
|
lines: List[str] = [] |
|
|
if isinstance(items, list): |
|
|
for item in items: |
|
|
if isinstance(item, str) and item.strip(): |
|
|
lines.append(item.strip()) |
|
|
elif isinstance(item, dict): |
|
|
text = item.get("text") |
|
|
if isinstance(text, str) and text.strip(): |
|
|
lines.append(text.strip()) |
|
|
raw = "\n".join(lines).strip() |
|
|
return raw or None |
|
|
|
|
|
def _persona_attributes(self, persona: Dict[str, Any]) -> List[str]: |
|
|
attrs = persona.get("attributes") |
|
|
if not isinstance(attrs, list): |
|
|
return [] |
|
|
return [s.strip() for s in attrs if isinstance(s, str) and s.strip()] |
|
|
|
|
|
async def start_human_chat(
    self,
    conversation_id: str,
    surveyor_persona_id: str,
    patient_persona_id: str,
    host: Optional[str] = None,
    model: Optional[str] = None,
    patient_attributes: Optional[List[str]] = None,
    surveyor_system_prompt: Optional[str] = None,
    patient_system_prompt: Optional[str] = None,
    analysis_attributes: Optional[List[str]] = None,
    surveyor_attributes: Optional[List[str]] = None,
    surveyor_question_bank: Optional[str] = None,
    ai_role: Optional[str] = None,
) -> bool:
    """Start a new human-to-AI chat session.

    The AI plays ``ai_role`` and the human plays the other side. Prompts and
    analysis settings come from the persisted settings store (falling back
    to the built-in defaults). Persona attributes and the question bank come
    from the personas themselves.

    The opening turn is added under the session lock. When the AI is the
    surveyor it sends a generated greeting. When the AI is the patient, a
    system line and a "Hello?" from the patient are added. Failures while
    producing these opening messages are logged and do not abort the session.

    Args:
        conversation_id: Unique id; must not be in use by any conversation.
        surveyor_persona_id: Persona id for the surveyor.
        patient_persona_id: Persona id for the patient.
        host: LLM host override; defaults to ``settings.llm.host``.
        model: LLM model override; defaults to ``settings.llm.model``.
        patient_attributes, surveyor_system_prompt, patient_system_prompt,
        analysis_attributes, surveyor_attributes, surveyor_question_bank:
            Accepted for API compatibility but not read by this method.
        ai_role: "surveyor" or "patient"; any other value means "surveyor".

    Returns:
        True once the session is registered. False for a duplicate id, or
        for an unknown persona (which also sends an error to the
        conversation).
    """
    if conversation_id in self.active_conversations or conversation_id in self.active_human_chats:
        logger.warning(f"Conversation {conversation_id} already exists")
        return False

    surveyor_persona = self.persona_system.get_persona(surveyor_persona_id)
    patient_persona = self.persona_system.get_persona(patient_persona_id)
    if not surveyor_persona or not patient_persona:
        await self._send_error(conversation_id, "Invalid persona IDs")
        return False

    resolved_host = host or self.settings.llm.host
    resolved_model = model or self.settings.llm.model
    resolved_backend = self.settings.llm.backend
    resolved_ai_role = ai_role if ai_role in ("surveyor", "patient") else "surveyor"

    store = get_persona_store()
    sp = await store.get_setting("surveyor_system_prompt")
    pp = await store.get_setting("patient_system_prompt")
    asp = await store.get_setting("analysis_system_prompt")
    ap = await store.get_setting("analysis_attributes")
    resolved_surveyor_prompt = sp if isinstance(sp, str) and sp.strip() else DEFAULT_SURVEYOR_SYSTEM_PROMPT
    resolved_patient_prompt = pp if isinstance(pp, str) and pp.strip() else DEFAULT_PATIENT_SYSTEM_PROMPT
    resolved_analysis_prompt = asp if isinstance(asp, str) and asp.strip() else ""
    # Filter per entry: keep only non-blank strings from the stored list.
    resolved_analysis_attrs = (
        [s.strip() for s in ap if isinstance(s, str) and s.strip()]
        if isinstance(ap, list)
        else []
    )

    chat_info = HumanChatInfo(
        conversation_id=conversation_id,
        surveyor_persona_id=surveyor_persona_id,
        patient_persona_id=patient_persona_id,
        host=resolved_host,
        model=resolved_model,
        llm_backend=resolved_backend,
        surveyor_system_prompt=resolved_surveyor_prompt,
        patient_system_prompt=resolved_patient_prompt,
        analysis_system_prompt=resolved_analysis_prompt,
        analysis_attributes=resolved_analysis_attrs,
        patient_attributes=self._persona_attributes(patient_persona),
        surveyor_attributes=self._persona_attributes(surveyor_persona),
        surveyor_question_bank=self._persona_question_bank(surveyor_persona),
        ai_role=resolved_ai_role,
        status=ConversationStatus.STARTING,
        created_at=datetime.now(),
    )

    llm_parameters = self._build_llm_parameters()
    client_kwargs = {"host": resolved_host, "model": resolved_model}
    client_kwargs.update(llm_parameters)
    chat_info.client = create_llm_client(resolved_backend, **client_kwargs)

    self.active_human_chats[conversation_id] = chat_info
    self.transcripts[conversation_id] = []

    await self._send_status_update(conversation_id, ConversationStatus.STARTING)
    # Keep the stored status in step with what clients are told.
    chat_info.status = ConversationStatus.RUNNING
    await self._send_status_update(conversation_id, ConversationStatus.RUNNING)

    # Seed the opening turn under the session lock so a human message that
    # arrives concurrently (human_chat_message takes the same lock) cannot
    # interleave with it.
    async with chat_info.lock:
        if chat_info.stop_requested:
            # Stopped during setup; appending now would re-create the
            # discarded transcript.
            return True

        if chat_info.ai_role == "surveyor":
            try:
                greeting = await self._generate_human_chat_surveyor_message(
                    chat_info,
                    transcript=[],
                    user_prompt=(
                        "Please greet the patient briefly and ask your first survey question."
                    ),
                )
                if not chat_info.stop_requested:
                    await self._append_and_broadcast_transcript(
                        conversation_id=conversation_id,
                        role="surveyor",
                        persona=surveyor_persona.get("name", "Surveyor"),
                        content=greeting,
                    )
            except Exception as e:
                logger.exception(f"Failed to generate human-chat greeting: {e}")
        else:
            # The human plays the surveyor: open with a scene-setting system
            # line and the AI patient answering the phone.
            try:
                await self._append_and_broadcast_transcript(
                    conversation_id=conversation_id,
                    role="system",
                    persona="System",
                    content="You call the patient, and they picked up the phone.",
                )
                await self._append_and_broadcast_transcript(
                    conversation_id=conversation_id,
                    role="patient",
                    persona=patient_persona.get("name", "Patient"),
                    content="Hello?",
                )
            except Exception as e:
                logger.exception(f"Failed to inject human-chat AI-patient starter messages: {e}")

    return True
|
|
|
|
|
async def human_chat_message(self, conversation_id: str, text: str) -> None:
    """Record a human turn in a human-to-AI chat and append the AI's reply.

    The human's message is stored and broadcast under the role the AI does
    not play (``chat_info.ai_role``). The AI then generates a reply, which is
    stored and broadcast too. The exchange runs under the session lock, so
    turns are processed one at a time. Nothing happens once the session is
    stopped, completed, or in error.

    If no human chat with this id exists, an error is sent to the
    conversation. Exceptions from LLM generation propagate to the caller.
    """
    chat_info = self.active_human_chats.get(conversation_id)
    if not chat_info:
        await self._send_error(conversation_id, "Human chat not found")
        return

    async with chat_info.lock:
        if chat_info.stop_requested or chat_info.status in (ConversationStatus.COMPLETED, ConversationStatus.ERROR):
            return

        patient_persona = self.persona_system.get_persona(chat_info.patient_persona_id) or {}
        surveyor_persona = self.persona_system.get_persona(chat_info.surveyor_persona_id) or {}
        patient_name = patient_persona.get("name", "Patient")
        surveyor_name = surveyor_persona.get("name", "Surveyor")

        if chat_info.ai_role == "patient":
            # The human is the surveyor; the AI answers as the patient.
            await self._append_and_broadcast_transcript(
                conversation_id=conversation_id,
                role="surveyor",
                persona=f"{surveyor_name} (Human)",
                content=text,
            )
            reply = await self._generate_human_chat_patient_message(
                chat_info,
                transcript=self.transcripts.get(conversation_id, []),
                user_prompt=(
                    f"The interviewer just said: '{text}'. "
                    "Please respond naturally as your persona would."
                ),
            )
            reply_role, reply_persona = "patient", patient_name
        else:
            # The human is the patient; the AI answers as the surveyor.
            await self._append_and_broadcast_transcript(
                conversation_id=conversation_id,
                role="patient",
                persona=f"{patient_name} (Human)",
                content=text,
            )
            reply = await self._generate_human_chat_surveyor_message(
                chat_info,
                transcript=self.transcripts.get(conversation_id, []),
                user_prompt=(
                    f"The patient just said: '{text}'. Respond with a brief acknowledgment and ask an appropriate follow-up question."
                ),
            )
            reply_role, reply_persona = "surveyor", surveyor_name

        # stop_conversation does not take the lock, so the session may have
        # been stopped (and its transcript discarded) while the LLM call was
        # in flight. Appending now would silently re-create the transcript.
        if chat_info.stop_requested:
            return

        await self._append_and_broadcast_transcript(
            conversation_id=conversation_id,
            role=reply_role,
            persona=reply_persona,
            content=reply,
        )
|
|
|
|
|
async def end_human_chat(self, conversation_id: str) -> bool:
    """End a human-to-AI chat, run the resource-agent analysis, and clean up.

    Under the session lock, the status is set to COMPLETED and broadcast.
    Then ``_run_resource_agent`` is run with the question-bank ids the
    surveyor used. Finally the session and transcript are dropped and the
    LLM client is closed; this cleanup runs even if the broadcast or the
    analysis fails.

    Returns:
        False if no such chat is active; otherwise True (also when the chat
        had already been completed).

    Raises:
        Whatever ``_send_status_update`` or ``_run_resource_agent`` raise.
    """
    chat_info = self.active_human_chats.get(conversation_id)
    if not chat_info:
        return False

    async with chat_info.lock:
        if chat_info.status == ConversationStatus.COMPLETED:
            return True
        chat_info.status = ConversationStatus.COMPLETED

        try:
            await self._send_status_update(conversation_id, ConversationStatus.COMPLETED)
            await self._run_resource_agent(
                conversation_id,
                asked_question_ids=list(chat_info.asked_question_ids or []),
            )
        finally:
            # Without this, a failure above would leave the chat and its LLM
            # client registered forever: the COMPLETED status makes every
            # later end_human_chat call return early.
            self.active_human_chats.pop(conversation_id, None)
            self.transcripts.pop(conversation_id, None)
            if chat_info.client is not None:
                try:
                    await chat_info.client.close()
                except Exception:
                    logger.debug(
                        f"Failed to close LLM client for human chat {conversation_id}",
                        exc_info=True,
                    )

    return True
|
|
|
|
|
async def _append_and_broadcast_transcript( |
|
|
self, |
|
|
*, |
|
|
conversation_id: str, |
|
|
role: str, |
|
|
persona: str, |
|
|
content: str, |
|
|
) -> None: |
|
|
timestamp = datetime.now().isoformat() |
|
|
idx = len(self.transcripts.setdefault(conversation_id, [])) |
|
|
self.transcripts[conversation_id].append( |
|
|
{ |
|
|
"index": idx, |
|
|
"role": role, |
|
|
"persona": persona, |
|
|
"content": content, |
|
|
"timestamp": timestamp, |
|
|
} |
|
|
) |
|
|
await self.websocket_manager.send_to_conversation( |
|
|
conversation_id, |
|
|
{ |
|
|
"type": "conversation_message", |
|
|
"conversation_id": conversation_id, |
|
|
"role": role, |
|
|
"persona": persona, |
|
|
"content": content, |
|
|
"timestamp": timestamp, |
|
|
}, |
|
|
) |
|
|
|
|
|
async def _generate_human_chat_surveyor_message(
    self,
    chat_info: HumanChatInfo,
    *,
    transcript: List[Dict[str, Any]],
    user_prompt: str,
) -> str:
    """Generate the AI surveyor's next message in a human chat.

    The system prompt is the surveyor persona prompt plus, when present, the
    question-bank overlay, the surveyor attribute overlay, and the patient's
    background and attributes as context.

    When the surveyor has a question bank, the model is asked for a JSON
    object naming the chosen question id and the message. A surrounding
    markdown code fence is tolerated. The id is recorded in
    ``chat_info.asked_question_ids`` and only the message is returned. A
    plain-text (non-JSON) reply is used as-is.

    Args:
        chat_info: Active human chat (personas, prompts, overlays, LLM client).
        transcript: Prior messages; surveyor turns are sent as ``assistant``
            and all other turns as ``user``.
        user_prompt: Instruction for this turn.

    Returns:
        The message text, or a fixed apology if nothing usable was produced.
    """
    conversation_history = [
        {"role": "assistant" if msg.get("role") == "surveyor" else "user", "content": msg.get("content", "")}
        for msg in (transcript or [])
    ]

    system_prompt, prompt_with_history = self.persona_system.build_conversation_prompt(
        persona_id=chat_info.surveyor_persona_id,
        conversation_history=conversation_history,
        user_prompt=user_prompt,
        base_system_prompt=chat_info.surveyor_system_prompt,
    )
    # Guard against a None prompt before concatenating overlays onto it.
    system_prompt = (system_prompt or "").strip()

    qb = compile_question_bank_overlay(chat_info.surveyor_question_bank)
    if qb:
        system_prompt = (system_prompt + "\n\n" + qb).strip()

    attrs = compile_surveyor_attributes_overlay(chat_info.surveyor_attributes)
    if attrs:
        system_prompt = (system_prompt + "\n\n" + attrs).strip()

    # Give the surveyor the patient's background as read-only context.
    patient_persona = self.persona_system.get_persona(chat_info.patient_persona_id) or {}
    try:
        patient_context = self.persona_system.prompt_builder.build_system_prompt(patient_persona)
    except Exception:
        patient_context = patient_persona.get("system_prompt", "") or ""
    patient_context = (patient_context or "").strip()

    pat_lines = [s.strip() for s in (chat_info.patient_attributes or []) if isinstance(s, str) and s.strip()]
    if pat_lines:
        bullets = "\n".join(f"- {line}" for line in pat_lines)
        patient_context = (patient_context + "\n\nPatient attributes (for context only):\n" + bullets).strip()
    if patient_context:
        system_prompt = (system_prompt + "\n\nPatient background (for context only):\n" + patient_context).strip()

    final_prompt = prompt_with_history
    if chat_info.surveyor_question_bank:
        final_prompt = (
            f"{prompt_with_history}\n\n"
            "You must pick exactly ONE question from the question bank that has not been asked yet and fits the flow.\n"
            f"Already asked question ids: {chat_info.asked_question_ids}\n\n"
            "Return STRICT JSON only (no markdown):\n"
            "{\n"
            "  \"selected_question_id\": string, // e.g. \"q01\"\n"
            "  \"message\": string\n"
            "}\n"
        )

    response = await chat_info.client.generate(
        prompt=final_prompt,
        system_prompt=system_prompt,
        max_tokens=SURVEYOR_MAX_TOKENS,
        temperature=0.4,
    )
    cleaned = (response or "").strip()

    if chat_info.surveyor_question_bank and cleaned:
        payload = cleaned
        if payload.startswith("```"):
            # Models often wrap JSON in a markdown fence (``` or ```json);
            # drop the opening fence line and the closing fence.
            _, _, payload = payload.partition("\n")
            payload = payload.strip()
            if payload.endswith("```"):
                payload = payload[:-3].strip()
        try:
            parsed = json.loads(payload)
        except (ValueError, RecursionError):
            # Not JSON: the model answered in plain text; use it verbatim.
            parsed = None
        if isinstance(parsed, dict):
            qid = parsed.get("selected_question_id")
            msg = parsed.get("message")
            if isinstance(qid, str) and qid and qid not in chat_info.asked_question_ids:
                chat_info.asked_question_ids.append(qid)
            # Never forward the raw JSON object to the human; without a usable
            # message, fall through to the apology below.
            cleaned = msg.strip() if isinstance(msg, str) and msg.strip() else ""

    return cleaned or "I apologize—I'm having trouble responding right now. Could you repeat that?"
|
|
|
|
|
async def _generate_human_chat_patient_message(
    self,
    chat_info: HumanChatInfo,
    *,
    transcript: List[Dict[str, Any]],
    user_prompt: str,
) -> str:
    """Generate the AI patient's next reply in a human chat.

    The transcript is replayed from the patient's point of view: patient
    turns become ``assistant`` and all other turns ``user``. The patient
    persona prompt is then extended with the patient-attribute overlay, if
    one is produced. A short apology is returned when the model yields only
    whitespace or nothing.
    """
    history: List[Dict[str, Any]] = []
    for entry in transcript or []:
        speaker = "assistant" if entry.get("role") == "patient" else "user"
        history.append({"role": speaker, "content": entry.get("content", "")})

    base_prompt, prompt_with_history = self.persona_system.build_conversation_prompt(
        persona_id=chat_info.patient_persona_id,
        conversation_history=history,
        user_prompt=user_prompt,
        base_system_prompt=getattr(chat_info, "patient_system_prompt", None),
    )

    system_prompt = (base_prompt or "").strip()
    overlay = compile_patient_attributes_overlay(chat_info.patient_attributes)
    if overlay:
        system_prompt = (system_prompt + "\n\n" + overlay).strip()

    response = await chat_info.client.generate(
        prompt=prompt_with_history,
        system_prompt=system_prompt,
        max_tokens=PATIENT_MAX_TOKENS,
        temperature=0.7,
    )
    reply = (response or "").strip()
    if reply:
        return reply
    return "I'm sorry—I'm having trouble responding right now."
|
|
|
|
|
async def start_conversation(
    self,
    conversation_id: str,
    surveyor_persona_id: str,
    patient_persona_id: str,
    host: Optional[str] = None,
    model: Optional[str] = None,
    patient_attributes: Optional[List[str]] = None,
    surveyor_system_prompt: Optional[str] = None,
    patient_system_prompt: Optional[str] = None,
    analysis_attributes: Optional[List[str]] = None,
    surveyor_attributes: Optional[List[str]] = None,
    surveyor_question_bank: Optional[str] = None,
) -> bool:
    """Start a new AI-to-AI conversation and stream it in the background.

    Args:
        conversation_id: Unique identifier; must not already be used by an
            AI conversation or a human chat.
        surveyor_persona_id: ID of a persona listed under "surveyor".
        patient_persona_id: ID of a persona listed under "patient".
        host: LLM server host; defaults to ``settings.llm.host``.
        model: LLM model to use; defaults to ``settings.llm.model``.
        patient_attributes, surveyor_system_prompt, patient_system_prompt,
        analysis_attributes, surveyor_attributes, surveyor_question_bank:
            Accepted for API compatibility but not read here. These values
            come from the personas and the persisted settings store.

    Returns:
        True if the conversation was registered and its streaming task
        created. False on a duplicate id, an unknown persona, or any error
        during setup (the last two also send an error to the conversation).
        A failed setup is rolled back so the id can be reused.
    """
    if conversation_id in self.active_conversations or conversation_id in self.active_human_chats:
        logger.warning(f"Conversation {conversation_id} already exists")
        return False

    conv_info: Optional[ConversationInfo] = None
    try:
        surveyors = self.persona_system.list_personas("surveyor")
        patients = self.persona_system.list_personas("patient")
        surveyor_persona = next((p for p in surveyors if p.get("id") == surveyor_persona_id), None)
        patient_persona = next((p for p in patients if p.get("id") == patient_persona_id), None)
        if not surveyor_persona or not patient_persona:
            await self._send_error(conversation_id, "Invalid persona IDs")
            return False

        resolved_host = host or self.settings.llm.host
        resolved_model = model or self.settings.llm.model
        resolved_backend = self.settings.llm.backend

        store = get_persona_store()
        sp = await store.get_setting("surveyor_system_prompt")
        pp = await store.get_setting("patient_system_prompt")
        asp = await store.get_setting("analysis_system_prompt")
        ap = await store.get_setting("analysis_attributes")
        resolved_surveyor_prompt = sp if isinstance(sp, str) and sp.strip() else DEFAULT_SURVEYOR_SYSTEM_PROMPT
        resolved_patient_prompt = pp if isinstance(pp, str) and pp.strip() else DEFAULT_PATIENT_SYSTEM_PROMPT
        resolved_analysis_prompt = asp if isinstance(asp, str) and asp.strip() else ""
        # Filter per entry: keep only non-blank strings from the stored list.
        resolved_analysis_attrs = (
            [s.strip() for s in ap if isinstance(s, str) and s.strip()]
            if isinstance(ap, list)
            else []
        )

        conv_info = ConversationInfo(
            conversation_id=conversation_id,
            surveyor_persona_id=surveyor_persona_id,
            patient_persona_id=patient_persona_id,
            host=resolved_host,
            model=resolved_model,
            llm_backend=resolved_backend,
            surveyor_system_prompt=resolved_surveyor_prompt,
            patient_system_prompt=resolved_patient_prompt,
            analysis_system_prompt=resolved_analysis_prompt,
            analysis_attributes=resolved_analysis_attrs,
            patient_attributes=self._persona_attributes(patient_persona),
            surveyor_attributes=self._persona_attributes(surveyor_persona),
            surveyor_question_bank=self._persona_question_bank(surveyor_persona),
            status=ConversationStatus.STARTING,
            created_at=datetime.now(),
        )

        self.active_conversations[conversation_id] = conv_info
        self.transcripts[conversation_id] = []

        await self._send_status_update(conversation_id, ConversationStatus.STARTING)

        llm_parameters = self._build_llm_parameters()
        manager = ConversationManager(
            surveyor_persona=surveyor_persona,
            patient_persona=patient_persona,
            host=resolved_host,
            model=resolved_model,
            llm_backend=resolved_backend,
            llm_parameters=llm_parameters,
            surveyor_system_prompt=conv_info.surveyor_system_prompt,
            patient_system_prompt=conv_info.patient_system_prompt,
            patient_attributes=conv_info.patient_attributes,
            surveyor_attributes=conv_info.surveyor_attributes,
            surveyor_question_bank=conv_info.surveyor_question_bank,
        )

        # Stream in the background; this method returns immediately.
        conv_info.task = asyncio.create_task(
            self._stream_conversation(conversation_id, manager)
        )

        conv_info.status = ConversationStatus.RUNNING
        await self._send_status_update(conversation_id, ConversationStatus.RUNNING)

        logger.info(f"Started conversation {conversation_id}")
        return True

    except Exception as e:
        logger.exception(f"Failed to start conversation {conversation_id}: {e}")
        await self._send_error(conversation_id, f"Failed to start conversation: {str(e)}")

        # Roll back only what this call registered. Another conversation may
        # have claimed the id during the awaits above; it must not be removed.
        if conv_info is not None:
            if conv_info.task is not None and not conv_info.task.done():
                conv_info.task.cancel()
            if self.active_conversations.get(conversation_id) is conv_info:
                del self.active_conversations[conversation_id]
                self.transcripts.pop(conversation_id, None)

        return False
|
|
|
|
|
async def stop_conversation(self, conversation_id: str) -> bool:
    """Stop an active AI conversation or human chat.

    Human chat: the session is flagged as stopped, STOPPING and COMPLETED
    status updates are sent, the session and transcript are dropped, and its
    LLM client is closed (best effort).

    AI conversation: the streaming task is cancelled and awaited, then the
    same status updates and cleanup happen. If the task fails with an error
    other than cancellation during teardown, that error is logged and the
    stop still completes.

    Args:
        conversation_id: ID of conversation to stop.

    Returns:
        True if the conversation stopped successfully. False if it was not
        found or an error occurred (an error message is then sent to it).
    """
    if conversation_id not in self.active_conversations and conversation_id not in self.active_human_chats:
        logger.warning(f"Conversation {conversation_id} not found")
        return False

    if conversation_id in self.active_human_chats:
        chat_info = self.active_human_chats[conversation_id]
        try:
            chat_info.stop_requested = True
            chat_info.status = ConversationStatus.STOPPING
            await self._send_status_update(conversation_id, ConversationStatus.STOPPING)

            chat_info.status = ConversationStatus.COMPLETED
            await self._send_status_update(conversation_id, ConversationStatus.COMPLETED)

            self.active_human_chats.pop(conversation_id, None)
            self.transcripts.pop(conversation_id, None)

            if chat_info.client is not None:
                try:
                    await chat_info.client.close()
                except Exception:
                    logger.debug(
                        f"Failed to close LLM client for human chat {conversation_id}",
                        exc_info=True,
                    )

            logger.info(f"Stopped human chat {conversation_id}")
            return True
        except Exception as e:
            logger.error(f"Error stopping human chat {conversation_id}: {e}")
            await self._send_error(conversation_id, f"Error stopping conversation: {str(e)}")
            return False

    conv_info = self.active_conversations[conversation_id]

    try:
        conv_info.stop_requested = True
        conv_info.status = ConversationStatus.STOPPING
        await self._send_status_update(conversation_id, ConversationStatus.STOPPING)

        task = conv_info.task
        if task and not task.done():
            task.cancel()
            try:
                await task
            except asyncio.CancelledError:
                pass
            except Exception:
                # The stream failed on its own while being torn down. Record
                # it, but finish the stop so the conversation is not left
                # registered in an ERROR state.
                logger.exception(f"Conversation {conversation_id} task failed during shutdown")

        conv_info.status = ConversationStatus.COMPLETED
        await self._send_status_update(conversation_id, ConversationStatus.COMPLETED)

        # The entry may already have been removed during the awaits above
        # (e.g. by a concurrent stop), and a new conversation may even have
        # reused the id. Only remove the state this call owns.
        if self.active_conversations.get(conversation_id) is conv_info:
            del self.active_conversations[conversation_id]
            self.transcripts.pop(conversation_id, None)

        logger.info(f"Stopped conversation {conversation_id}")
        return True

    except Exception as e:
        logger.error(f"Error stopping conversation {conversation_id}: {e}")
        conv_info.status = ConversationStatus.ERROR
        await self._send_error(conversation_id, f"Error stopping conversation: {str(e)}")
        return False
|
|
|
|
|
async def get_conversation_status(self, conversation_id: str) -> Optional[Dict]:
    """Describe one conversation tracked in ``active_conversations``.

    Human chats (``active_human_chats``) are not consulted here.

    Args:
        conversation_id: Identifier of the conversation to look up.

    Returns:
        A dict with ``conversation_id``, ``status`` (the status enum's
        ``value``), both persona IDs, ``created_at`` as an ISO-8601 string and
        ``message_count``; ``None`` when the ID is not tracked.
    """
    registry = self.active_conversations
    if conversation_id in registry:
        info = registry[conversation_id]
        return {
            "conversation_id": conversation_id,
            "status": info.status.value,
            "surveyor_persona_id": info.surveyor_persona_id,
            "patient_persona_id": info.patient_persona_id,
            "created_at": info.created_at.isoformat(),
            "message_count": info.message_count,
        }
    return None
|
|
|
|
|
async def list_active_conversations(self) -> Dict[str, Dict]:
    """Summarize every conversation currently held in ``active_conversations``.

    Returns:
        Mapping of conversation ID to a dict with ``status`` (the status
        enum's ``value``), both persona IDs, ``created_at`` as an ISO-8601
        string and ``message_count``. Empty when nothing is active.
    """
    return {
        conv_id: {
            "status": info.status.value,
            "surveyor_persona_id": info.surveyor_persona_id,
            "patient_persona_id": info.patient_persona_id,
            "created_at": info.created_at.isoformat(),
            "message_count": info.message_count,
        }
        for conv_id, info in self.active_conversations.items()
    }
|
|
|
|
|
async def _stream_conversation(self, conversation_id: str, manager: ConversationManager):
    """Stream conversation messages to WebSocket clients.

    Iterates ``manager.conduct_conversation()``. For each message it:

    - increments ``message_count``;
    - appends an entry to ``self.transcripts[conversation_id]``;
    - forwards the message to subscribers as a ``conversation_message`` event.

    On exit (normal end, error, or cancellation):

    - the manager is closed;
    - a COMPLETED status is broadcast unless the conversation already ended in
      ERROR;
    - if no stop was requested, the post-conversation resource agent is run.

    Registry teardown (removing the entry from ``active_conversations`` and
    ``transcripts``) is always attempted, even if the status broadcast or the
    resource agent raises. The one exception is when ``stop_requested`` is set.
    In that case ``stop_conversation`` removes the entry itself after awaiting
    the task; popping it here first would make its ``del`` raise ``KeyError``.

    Args:
        conversation_id: ID of the conversation
        manager: ConversationManager instance to stream from

    Raises:
        asyncio.CancelledError: Re-raised when the streaming task is cancelled.
    """
    conv_info = self.active_conversations.get(conversation_id)
    if not conv_info:
        return

    try:
        async for message in manager.conduct_conversation():
            # Stop forwarding once the conversation has been deregistered.
            if conversation_id not in self.active_conversations:
                break

            conv_info.message_count += 1

            try:
                self.transcripts.setdefault(conversation_id, []).append({
                    "index": conv_info.message_count - 1,
                    "role": message.get("role", "unknown"),
                    "persona": message.get("persona"),
                    "content": message.get("content", ""),
                    "timestamp": message.get("timestamp"),
                })
            except Exception as e:
                # Best effort: a transcript failure must not interrupt live streaming,
                # but it should be visible in the logs.
                logger.warning(f"Failed to record transcript entry for conversation {conversation_id}: {e}")

            websocket_message = {
                "type": "conversation_message",
                "conversation_id": conversation_id,
                **message
            }
            await self.websocket_manager.send_to_conversation(
                conversation_id, websocket_message
            )

            # ``or ""`` keeps the log line from raising when content is None.
            content_len = len(message.get("content") or "")
            logger.info(f"Streamed message {conv_info.message_count} for conversation {conversation_id}: {message.get('role', 'unknown')} - {content_len} chars")

    except asyncio.CancelledError:
        logger.info(f"Conversation {conversation_id} streaming cancelled")
        raise

    except Exception as e:
        logger.error(f"Error streaming conversation {conversation_id}: {e}")
        conv_info.status = ConversationStatus.ERROR
        await self._send_error(conversation_id, f"Streaming error: {str(e)}")

    finally:
        try:
            try:
                await manager.close()
            except Exception as e:
                # Closing is best-effort; do not let it mask the stream outcome.
                logger.warning(f"Error closing conversation manager for {conversation_id}: {e}")

            if conv_info.status != ConversationStatus.ERROR:
                conv_info.status = ConversationStatus.COMPLETED
                await self._send_status_update(conversation_id, ConversationStatus.COMPLETED)

            # Only analyze conversations that ran to their natural end (or errored);
            # an explicit stop skips the resource agent.
            if not conv_info.stop_requested:
                asked_ids = None
                try:
                    asked_ids = list(getattr(manager, "asked_question_ids", None) or [])
                except Exception:
                    asked_ids = None
                await self._run_resource_agent(conversation_id, asked_question_ids=asked_ids)
        finally:
            # Release registry entries even if the steps above raised, so a
            # failed broadcast or analysis cannot leak the conversation. When a
            # stop was requested, stop_conversation owns this teardown.
            if not conv_info.stop_requested:
                self.active_conversations.pop(conversation_id, None)
                self.transcripts.pop(conversation_id, None)
|
|
|
|
|
async def _run_resource_agent(self, conversation_id: str, *, asked_question_ids: Optional[List[str]] = None):
    """Run post-conversation resource agent analysis, persist it, and broadcast results.

    Uses the in-memory transcript for ``conversation_id`` and the matching entry
    from ``active_conversations`` or ``active_human_chats``. Returns without
    sending anything when there is no transcript or no tracked entry.

    Events sent to the conversation's subscribers:

    - ``resource_agent_status`` with status ``"running"``;
    - on success, ``resource_agent_result`` followed by ``resource_agent_status``
      ``"complete"``;
    - on failure, ``resource_agent_status`` ``"error"``.

    The analysis is also saved as a sealed ``RunRecord`` whose run ID is the
    conversation ID. A persistence failure is logged and reported as
    ``persisted: False`` rather than failing the analysis.

    Args:
        conversation_id: ID of the conversation (also used as the run ID).
        asked_question_ids: Question IDs the surveyor asked; recorded in the
            run's config snapshot.
    """
    transcript = self.transcripts.get(conversation_id, [])
    if not transcript:
        return

    # Resolve the conversation before announcing "running" so clients never see
    # a "running" status that is not followed by "complete" or "error".
    conv_info = self.active_conversations.get(conversation_id) or self.active_human_chats.get(conversation_id)
    if not conv_info:
        return

    await self.websocket_manager.send_to_conversation(conversation_id, {
        "type": "resource_agent_status",
        "conversation_id": conversation_id,
        "status": "running",
        "timestamp": datetime.now().isoformat(),
    })

    try:
        seal_timestamp = datetime.now().isoformat()
        parsed = await run_resource_agent_analysis(
            transcript=transcript,
            llm_backend=conv_info.llm_backend,
            host=conv_info.host,
            model=conv_info.model,
            settings=self.settings,
            analysis_attributes=getattr(conv_info, "analysis_attributes", None),
            analysis_system_prompt=getattr(conv_info, "analysis_system_prompt", None),
        )

        # Persistence is best-effort: a storage failure is logged and surfaced via
        # ``persisted: False`` but the analysis is still delivered to clients.
        persisted = False
        run_id = None
        try:
            store = get_run_store()
            mode = "human_to_ai" if conversation_id in self.active_human_chats else "ai_to_ai"

            persona_snapshots: Dict[str, Dict[str, Any]] = {}
            try:
                surveyor_persona = self.persona_system.get_persona(conv_info.surveyor_persona_id) or {}
                patient_persona = self.persona_system.get_persona(conv_info.patient_persona_id) or {}
                persona_snapshots = {
                    "surveyor": {
                        "persona_id": conv_info.surveyor_persona_id,
                        "persona_version_id": surveyor_persona.get("version_id"),
                        "snapshot": surveyor_persona,
                    },
                    "patient": {
                        "persona_id": conv_info.patient_persona_id,
                        "persona_version_id": patient_persona.get("version_id"),
                        "snapshot": patient_persona,
                    },
                }
            except Exception as e:
                # Snapshots are optional metadata; keep persisting without them.
                logger.warning(f"Could not snapshot personas for run {conversation_id}: {e}")
                persona_snapshots = {}

            config_snapshot: Dict[str, Any] = {
                "llm": {
                    "backend": conv_info.llm_backend,
                    "host": conv_info.host,
                    "model": conv_info.model,
                    "timeout": self.settings.llm.timeout,
                    "max_retries": self.settings.llm.max_retries,
                    "retry_delay": self.settings.llm.retry_delay,
                },
                "personas": {
                    "surveyor_persona_id": conv_info.surveyor_persona_id,
                    "patient_persona_id": conv_info.patient_persona_id,
                    "surveyor_system_prompt": getattr(conv_info, "surveyor_system_prompt", None),
                    "patient_system_prompt": getattr(conv_info, "patient_system_prompt", None),
                    "patient_attributes": getattr(conv_info, "patient_attributes", None),
                    "surveyor_attributes": getattr(conv_info, "surveyor_attributes", None),
                    "surveyor_question_bank": getattr(conv_info, "surveyor_question_bank", None),
                    "asked_question_ids": asked_question_ids,
                },
                "analysis": {
                    "analysis_system_prompt": getattr(conv_info, "analysis_system_prompt", None),
                    "analysis_attributes": getattr(conv_info, "analysis_attributes", None),
                },
            }

            run_id = conversation_id
            record = RunRecord(
                run_id=run_id,
                mode=mode,
                status="completed",
                created_at=conv_info.created_at.isoformat(),
                ended_at=seal_timestamp,
                sealed_at=seal_timestamp,
                title=None,
                input_summary=None,
                config=config_snapshot,
                messages=transcript,
                analyses={"resource_agent_v2": parsed},
                persona_snapshots=persona_snapshots,
            )
            await store.save_sealed_run(record)
            persisted = True
        except Exception as e:
            logger.error(f"Failed to persist sealed run {conversation_id}: {e}")

        await self.websocket_manager.send_to_conversation(conversation_id, {
            "type": "resource_agent_result",
            "conversation_id": conversation_id,
            "run_id": run_id if persisted else None,
            "persisted": persisted,
            "data": parsed,
            "timestamp": datetime.now().isoformat(),
        })
        await self.websocket_manager.send_to_conversation(conversation_id, {
            "type": "resource_agent_status",
            "conversation_id": conversation_id,
            "status": "complete",
            "timestamp": datetime.now().isoformat(),
        })
    except Exception as e:
        logger.error(f"Resource agent failed for {conversation_id}: {e}")
        await self.websocket_manager.send_to_conversation(conversation_id, {
            "type": "resource_agent_status",
            "conversation_id": conversation_id,
            "status": "error",
            "error": str(e),
            "timestamp": datetime.now().isoformat(),
        })
|
|
|
|
|
def _build_llm_parameters(self) -> Dict[str, Any]: |
|
|
"""Prepare keyword arguments for LLM client creation.""" |
|
|
params: Dict[str, Any] = { |
|
|
"timeout": self.settings.llm.timeout, |
|
|
"max_retries": self.settings.llm.max_retries, |
|
|
"retry_delay": self.settings.llm.retry_delay, |
|
|
} |
|
|
|
|
|
if self.settings.llm.api_key: |
|
|
params["api_key"] = self.settings.llm.api_key |
|
|
if self.settings.llm.site_url: |
|
|
params["site_url"] = self.settings.llm.site_url |
|
|
if self.settings.llm.app_name: |
|
|
params["app_name"] = self.settings.llm.app_name |
|
|
|
|
|
return params |
|
|
|
|
|
async def _send_status_update(self, conversation_id: str, status: ConversationStatus):
    """Broadcast a ``conversation_status`` event to one conversation's clients.

    Args:
        conversation_id: Conversation whose subscribers receive the event.
        status: Status to announce; its ``value`` is what goes on the wire,
            together with a local-time ISO-8601 timestamp.
    """
    payload = dict(
        type="conversation_status",
        conversation_id=conversation_id,
        status=status.value,
        timestamp=datetime.now().isoformat(),
    )
    await self.websocket_manager.send_to_conversation(conversation_id, payload)
|
|
|
|
|
async def _send_error(self, conversation_id: str, error_message: str): |
|
|
"""Send error message to clients. |
|
|
|
|
|
Args: |
|
|
conversation_id: ID of the conversation |
|
|
error_message: Error description |
|
|
""" |
|
|
message = { |
|
|
"type": "conversation_error", |
|
|
"conversation_id": conversation_id, |
|
|
"error": error_message, |
|
|
"timestamp": datetime.now().isoformat() |
|
|
} |
|
|
|
|
|
await self.websocket_manager.send_to_conversation(conversation_id, message) |
|
|
|
|
|
async def cleanup(self):
    """Stop every tracked conversation via ``stop_conversation``.

    AI-to-AI entries (``active_conversations``) are stopped first, then human
    chats (``active_human_chats``). Each registry's keys are snapshotted just
    before it is processed, because stopping mutates the registries.
    """
    for registry in (self.active_conversations, self.active_human_chats):
        for conversation_id in list(registry):
            await self.stop_conversation(conversation_id)
|
|
|
|
|
|
|
|
|
|
|
# Process-wide ConversationService singleton. It stays None until
# initialize_conversation_service() assigns it; get_conversation_service()
# raises RuntimeError while it is still None.
conversation_service: Optional[ConversationService] = None
|
|
|
|
|
|
|
|
def get_conversation_service() -> ConversationService:
    """Return the process-wide ConversationService singleton.

    Returns:
        The instance installed by ``initialize_conversation_service``.

    Raises:
        RuntimeError: If ``initialize_conversation_service`` has not run yet.
    """
    service = conversation_service
    if service is not None:
        return service
    raise RuntimeError("ConversationService not initialized")
|
|
|
|
|
|
|
|
def initialize_conversation_service(websocket_manager: ConnectionManager, settings: Optional[AppSettings] = None):
    """Create the ConversationService and install it as the module singleton.

    Any previously installed instance is replaced.

    Args:
        websocket_manager: Connection manager the service broadcasts through.
        settings: Application settings forwarded to the service; optional.
    """
    global conversation_service
    service = ConversationService(websocket_manager, settings=settings)
    conversation_service = service
    logger.info("ConversationService initialized")
|
|
|