# meet4150/alive_pine: app/agent/health_agent.py
from __future__ import annotations
import os
from pathlib import Path
from threading import Lock
import ollama
import torch
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
from app.agent.kb_retrieval import retrieve_with_scores
from app.nlp.nlp_service import NLPService
class HFChatService:
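    """Thread-safe singleton around a local Hugging Face seq2seq model.

    Used as a generation fallback when the primary Ollama backend is
    unavailable; the tokenizer and model are loaded lazily on first use.
    """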
_instance: "HFChatService | None" = None
_instance_lock = Lock()
_model_lock = Lock()
_default_model_name = os.getenv("ALIVEAI_HF_MODEL", "google/flan-t5-small")
_local_model_dir = (
Path(__file__).resolve().parents[2] / "models" / "google__flan-t5-small"
)
def __new__(cls) -> "HFChatService":
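        """Create or reuse the process-wide instance (double-checked locking)."""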
if cls._instance is None:
with cls._instance_lock:
if cls._instance is None:
cls._instance = super().__new__(cls)
cls._instance._tokenizer = None
cls._instance._model = None
return cls._instance
def _load(self) -> None:
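        """Lazily load the tokenizer and model, preferring a local snapshot.

        Falls back to downloading ALIVEAI_HF_MODEL from the Hugging Face Hub
        when no local model directory exists.
        """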
if self._model is not None and self._tokenizer is not None:
return
with self._model_lock:
if self._model is not None and self._tokenizer is not None:
return
model_source = (
str(self._local_model_dir)
if self._local_model_dir.exists()
else self._default_model_name
)
self._tokenizer = AutoTokenizer.from_pretrained(model_source)
self._model = AutoModelForSeq2SeqLM.from_pretrained(model_source)
self._model.eval()
print(f"Hugging Face fallback model loaded: {self._default_model_name}")
def generate(self, prompt: str, max_new_tokens: int = 220) -> str:
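        """Greedily decode a completion for ``prompt`` (input capped at 1024 tokens)."""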
self._load()
inputs = self._tokenizer(
prompt,
return_tensors="pt",
truncation=True,
max_length=1024,
)
with torch.no_grad():
            output_ids = self._model.generate(
                **inputs,
                max_new_tokens=max_new_tokens,
                do_sample=False,  # greedy decoding; temperature is ignored (and warned about) when sampling is off
            )
return self._tokenizer.decode(output_ids[0], skip_special_tokens=True).strip()
class HealthAgent:
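    """Retrieval-augmented health assistant.

    Routes messages through NLP intent detection, retrieves knowledge-base
    context, and generates answers via Ollama, falling back to a local
    Hugging Face model or a purely extractive answer when generation fails.
    """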
_section_summary = "Summary:"
_section_causes = "Possible causes from context:"
_section_seek_care = "When to seek care:"
def __init__(self, model: str = "llama3.2:3b") -> None:
self.model = model
self.conversation_history: list[dict[str, str]] = []
self.system_prompt = (
"You are a compassionate medical assistant. Answer health questions clearly "
"in simple, layman language. Only use the provided context to answer. "
"If context is insufficient, say you don't have enough information. "
"Never diagnose. Always recommend consulting a doctor for serious symptoms. "
"Respond strictly using this format:\n"
"Summary:\n"
"<2-3 concise lines>\n\n"
"Possible causes from context:\n"
"- <cause from provided sources>\n"
"- <cause from provided sources>\n\n"
"When to seek care:\n"
"- <urgent warning signs>\n"
"- <when to contact a doctor>"
)
def _append_history(self, user_message: str, assistant_message: str) -> None:
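        """Record a user/assistant exchange, keeping only the last 20 messages."""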
self.conversation_history.append({"role": "user", "content": user_message})
self.conversation_history.append({"role": "assistant", "content": assistant_message})
self.conversation_history = self.conversation_history[-20:]
@staticmethod
def _extract_possible_causes(chunks: list[dict], limit: int = 3) -> list[str]:
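        """Pull up to ``limit`` deduplicated lead sentences from retrieved chunks."""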
causes: list[str] = []
seen: set[str] = set()
for chunk in chunks:
content = (chunk.get("content") or "").strip()
if not content:
continue
if ": " in content:
content = content.split(": ", 1)[1]
sentence = content.split(". ")[0].strip(" .")
if not sentence:
continue
normalized = sentence.lower()
if normalized in seen:
continue
seen.add(normalized)
causes.append(sentence)
if len(causes) >= limit:
break
return causes
@classmethod
def _structured_response(cls, summary: str, causes: list[str]) -> str:
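        """Render summary, causes, and care guidance in the required section format."""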
safe_summary = (summary or "").strip() or "I could not find enough context to give a complete explanation."
safe_causes = causes or ["Not enough specific context was retrieved from the knowledge base."]
seek_care = [
"Seek emergency care right away for severe chest pain, breathing difficulty, fainting, or confusion.",
"Consult a doctor soon if symptoms continue, worsen, or keep returning.",
]
causes_lines = "\n".join(f"- {cause}" for cause in safe_causes)
seek_care_lines = "\n".join(f"- {line}" for line in seek_care)
return (
f"{cls._section_summary}\n"
f"{safe_summary}\n\n"
f"{cls._section_causes}\n"
f"{causes_lines}\n\n"
f"{cls._section_seek_care}\n"
f"{seek_care_lines}"
)
@classmethod
def _extractive_fallback_answer(cls, chunks: list[dict]) -> str:
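        """Build a structured answer directly from retrieved chunks, with no generation step."""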
if not chunks:
return cls._structured_response(
"I don't have enough information in my current knowledge base to answer that clearly.",
[],
)
        evidence = " ".join((chunk.get("content") or "") for chunk in chunks[:2]).strip()[:300]
causes = cls._extract_possible_causes(chunks)
return cls._structured_response(
f"Based on the retrieved medical context: {evidence}",
causes,
)
@classmethod
def _has_required_sections(cls, answer: str) -> bool:
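        """Check that all three required section headers appear in the answer."""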
normalized = (answer or "").lower()
return (
cls._section_summary.lower() in normalized
and cls._section_causes.lower() in normalized
and cls._section_seek_care.lower() in normalized
)
@staticmethod
def _is_poorly_formatted_structured_answer(answer: str) -> bool:
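        """Detect leftover template placeholders, duplicated headers, or overly short answers."""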
normalized = (answer or "").strip().lower()
if "cause from provided sources>" in normalized:
return True
if "summary:\nsummary:" in normalized:
return True
if len(normalized) < 120:
return True
return False
@classmethod
def _enforce_response_structure(cls, answer: str, chunks: list[dict]) -> str:
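        """Coerce any answer into the required sections, falling back to extraction."""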
if cls._has_required_sections(answer) and not cls._is_poorly_formatted_structured_answer(answer):
return answer
if cls._is_poorly_formatted_structured_answer(answer):
return cls._extractive_fallback_answer(chunks)
concise_summary = (answer or "").strip().replace("\n", " ")[:320]
causes = cls._extract_possible_causes(chunks)
return cls._structured_response(concise_summary, causes)
@staticmethod
def _is_low_quality_generated_answer(answer: str, intent: str) -> bool:
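        """Reject answers that are too short or escalate to emergency language unprompted."""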
normalized = (answer or "").strip().lower()
if len(normalized) < 20:
return True
if intent != "emergency" and ("call 911" in normalized or "call emergency services" in normalized):
return True
return False
def chat(self, user_message: str) -> dict:
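        """Answer a user message; return the response plus intent and context metadata."""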
nlp_result = NLPService().process(user_message)
intent = nlp_result["intent"]
disease_id = nlp_result["disease_id"]
if intent == "emergency":
return {
"response": "⚠️ This sounds like a medical emergency. Please call emergency services (112) immediately.",
"intent": "emergency",
"context_used": [],
}
if intent == "greeting":
return {
"response": "Hello! I'm your health assistant. How can I help you today?",
"intent": "greeting",
"context_used": [],
}
chunks = retrieve_with_scores(user_message, disease_id=disease_id, top_k=5)
if not chunks:
answer = self._extractive_fallback_answer(chunks)
self._append_history(user_message, answer)
return {
"response": answer,
"intent": intent,
"disease_id": disease_id,
"context_used": [],
"nlp_confidence": nlp_result["intent_confidence"],
}
context = "\n\n".join(
[f"[Source {index + 1}]: {chunk['content']}" for index, chunk in enumerate(chunks)]
)
messages = [
{"role": "system", "content": f"{self.system_prompt}\n\nContext:\n{context}"},
*self.conversation_history[-10:],
{"role": "user", "content": user_message},
]
try:
response = ollama.chat(model=self.model, messages=messages)
answer = response["message"]["content"].strip()
except Exception as ollama_exc: # pragma: no cover - depends on local Ollama runtime
fallback_mode = os.getenv("ALIVEAI_CHAT_FALLBACK", "auto").lower()
should_try_hf = fallback_mode in {"auto", "hf"}
if should_try_hf:
try:
history_snippet = "\n".join(
[f"{item['role'].title()}: {item['content']}" for item in self.conversation_history[-6:]]
)
hf_prompt = (
f"{self.system_prompt}\n\n"
f"Context:\n{context}\n\n"
f"Conversation history:\n{history_snippet}\n\n"
f"User: {user_message}\nAssistant:"
)
answer = HFChatService().generate(hf_prompt)
if self._is_low_quality_generated_answer(answer, intent):
answer = self._extractive_fallback_answer(chunks)
except Exception as hf_exc:
answer = self._extractive_fallback_answer(chunks)
answer += (
f" (Generation backend note: Ollama unavailable: {ollama_exc}; "
f"Hugging Face unavailable: {hf_exc})"
)
else:
answer = self._extractive_fallback_answer(chunks)
answer = self._enforce_response_structure(answer, chunks)
self._append_history(user_message, answer)
return {
"response": answer,
"intent": intent,
"disease_id": disease_id,
"context_used": [{"content": chunk["content"][:100], "score": chunk["score"]} for chunk in chunks],
"nlp_confidence": nlp_result["intent_confidence"],
}
def reset(self) -> None:
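        """Clear the conversation history."""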
self.conversation_history.clear()
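
# A minimal smoke-test sketch, not part of the service wiring: it assumes a
# local Ollama server with the llama3.2:3b model pulled, or that one of the
# fallbacks selected via ALIVEAI_CHAT_FALLBACK is available. The question
# below is illustrative only.
if __name__ == "__main__":
    agent = HealthAgent()
    result = agent.chat("What could cause a persistent dry cough?")
    print(result["response"])
    print("intent:", result["intent"], "| sources used:", len(result["context_used"]))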