respitriage / utils /symptom_validator.py
SujalSha's picture
Upload folder using huggingface_hub
d0ace1e verified
"""
utils/symptom_validator.py — LLM-based free-text symptom validator.
Uses Groq (llama-3.3-70b-versatile) to:
1. Validate each symptom is a real medical/respiratory symptom
2. Reject non-medical text (e.g. "love failure", "stressed", "bad day")
3. Map valid symptoms to a numeric boost score (0.0–1.0) that gets
added to the symptom_index in the pipeline
Returns:
{
'valid': [list of accepted symptom strings],
'invalid': [list of rejected symptom strings with reasons],
'boost': float — extra score to add to symptom_index (0.0–0.25 max)
'summary': str — short human-readable summary
}
"""
import os
import json
import re
# Try to load .env manually (no python-dotenv required)
def _load_env_key():
key = os.environ.get("GROQ_API_KEY", "")
if key:
return key
# Walk up from this file to find .env
base = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
env_path = os.path.join(base, ".env")
if os.path.exists(env_path):
with open(env_path) as f:
for line in f:
line = line.strip()
if line.startswith("GROQ_API_KEY="):
return line.split("=", 1)[1].strip().strip('"').strip("'")
return ""
_SYSTEM_PROMPT = """You are a strict medical symptom validator for a respiratory triage AI system.
Your job:
1. Evaluate each symptom the user entered.
2. Accept ONLY real medical/physical symptoms that could be relevant to respiratory or general health assessment.
3. Reject anything that is NOT a medical symptom — emotions, life events, relationship problems, vague non-medical phrases, nonsense, etc.
4. For each VALID symptom, assign a respiratory_relevance score (0.0 to 1.0):
- 1.0 = directly respiratory (e.g. "shortness of breath", "wheezing", "chest pain")
- 0.5 = general medical but relevant (e.g. "fever", "fatigue", "headache")
- 0.2 = mildly relevant (e.g. "nausea", "loss of appetite")
5. Return ONLY a JSON object, no extra text.
JSON format:
{
"results": [
{
"symptom": "<original text>",
"valid": true or false,
"reason": "<brief reason if invalid, or accepted symptom name if valid>",
"respiratory_relevance": 0.0
}
],
"boost": <float 0.0-0.25, total extra score from all valid symptoms>,
"summary": "<one sentence summary>"
}
Rules:
- "love failure", "heartbreak", "stress", "bad day", "boredom" → INVALID
- "chest pain", "cough", "breathlessness", "wheezing", "fever", "sore throat", "runny nose" → VALID
- "tired", "fatigue", "weakness" → VALID (respiratory_relevance: 0.3)
- The boost value = sum of (respiratory_relevance * 0.1) for all valid symptoms, capped at 0.25
- Be strict. When in doubt, reject it."""
def validate_symptoms(raw_text: str) -> dict:
"""
Validate free-text symptoms using Groq LLM.
Parameters
----------
raw_text : str
Comma or newline separated symptom text from patient input.
Returns
-------
dict with keys: valid, invalid, boost, summary, raw_results
"""
# Default safe return
default = {
'valid': [], 'invalid': [], 'boost': 0.0,
'summary': 'No additional symptoms provided.', 'raw_results': []
}
if not raw_text or not raw_text.strip():
return default
api_key = _load_env_key()
if not api_key:
print("[symptom_validator] No GROQ_API_KEY found — skipping LLM validation")
return default
# Split input into individual symptoms
symptoms = [s.strip() for s in re.split(r'[,\n;]+', raw_text) if s.strip()]
if not symptoms:
return default
# Cap at 10 symptoms
symptoms = symptoms[:10]
symptom_list = "\n".join(f"- {s}" for s in symptoms)
try:
from groq import Groq
client = Groq(api_key=api_key)
response = client.chat.completions.create(
model="llama-3.3-70b-versatile",
messages=[
{"role": "system", "content": _SYSTEM_PROMPT},
{"role": "user", "content": f"Validate these symptoms:\n{symptom_list}"}
],
temperature=0.1,
max_tokens=800,
)
content = response.choices[0].message.content.strip()
# Extract JSON from response (handle markdown code blocks)
json_match = re.search(r'\{[\s\S]*\}', content)
if not json_match:
print(f"[symptom_validator] Could not parse JSON from: {content[:200]}")
return default
data = json.loads(json_match.group())
results = data.get('results', [])
valid = [r['reason'] for r in results if r.get('valid')]
invalid = [
{'symptom': r['symptom'], 'reason': r.get('reason', 'Not a medical symptom')}
for r in results if not r.get('valid')
]
boost = float(data.get('boost', 0.0))
boost = max(0.0, min(0.25, boost)) # hard cap
summary = data.get('summary', '')
print(f"[symptom_validator] Valid: {valid} | Invalid: {[i['symptom'] for i in invalid]} | Boost: {boost:.3f}")
return {
'valid': valid,
'invalid': invalid,
'boost': boost,
'summary': summary,
'raw_results': results,
}
except Exception as e:
print(f"[symptom_validator] Error: {e}")
return default