""" utils/symptom_validator.py — LLM-based free-text symptom validator. Uses Groq (llama-3.3-70b-versatile) to: 1. Validate each symptom is a real medical/respiratory symptom 2. Reject non-medical text (e.g. "love failure", "stressed", "bad day") 3. Map valid symptoms to a numeric boost score (0.0–1.0) that gets added to the symptom_index in the pipeline Returns: { 'valid': [list of accepted symptom strings], 'invalid': [list of rejected symptom strings with reasons], 'boost': float — extra score to add to symptom_index (0.0–0.25 max) 'summary': str — short human-readable summary } """ import os import json import re # Try to load .env manually (no python-dotenv required) def _load_env_key(): key = os.environ.get("GROQ_API_KEY", "") if key: return key # Walk up from this file to find .env base = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) env_path = os.path.join(base, ".env") if os.path.exists(env_path): with open(env_path) as f: for line in f: line = line.strip() if line.startswith("GROQ_API_KEY="): return line.split("=", 1)[1].strip().strip('"').strip("'") return "" _SYSTEM_PROMPT = """You are a strict medical symptom validator for a respiratory triage AI system. Your job: 1. Evaluate each symptom the user entered. 2. Accept ONLY real medical/physical symptoms that could be relevant to respiratory or general health assessment. 3. Reject anything that is NOT a medical symptom — emotions, life events, relationship problems, vague non-medical phrases, nonsense, etc. 4. For each VALID symptom, assign a respiratory_relevance score (0.0 to 1.0): - 1.0 = directly respiratory (e.g. "shortness of breath", "wheezing", "chest pain") - 0.5 = general medical but relevant (e.g. "fever", "fatigue", "headache") - 0.2 = mildly relevant (e.g. "nausea", "loss of appetite") 5. Return ONLY a JSON object, no extra text. JSON format: { "results": [ { "symptom": "", "valid": true or false, "reason": "", "respiratory_relevance": 0.0 } ], "boost": , "summary": "" } Rules: - "love failure", "heartbreak", "stress", "bad day", "boredom" → INVALID - "chest pain", "cough", "breathlessness", "wheezing", "fever", "sore throat", "runny nose" → VALID - "tired", "fatigue", "weakness" → VALID (respiratory_relevance: 0.3) - The boost value = sum of (respiratory_relevance * 0.1) for all valid symptoms, capped at 0.25 - Be strict. When in doubt, reject it.""" def validate_symptoms(raw_text: str) -> dict: """ Validate free-text symptoms using Groq LLM. Parameters ---------- raw_text : str Comma or newline separated symptom text from patient input. Returns ------- dict with keys: valid, invalid, boost, summary, raw_results """ # Default safe return default = { 'valid': [], 'invalid': [], 'boost': 0.0, 'summary': 'No additional symptoms provided.', 'raw_results': [] } if not raw_text or not raw_text.strip(): return default api_key = _load_env_key() if not api_key: print("[symptom_validator] No GROQ_API_KEY found — skipping LLM validation") return default # Split input into individual symptoms symptoms = [s.strip() for s in re.split(r'[,\n;]+', raw_text) if s.strip()] if not symptoms: return default # Cap at 10 symptoms symptoms = symptoms[:10] symptom_list = "\n".join(f"- {s}" for s in symptoms) try: from groq import Groq client = Groq(api_key=api_key) response = client.chat.completions.create( model="llama-3.3-70b-versatile", messages=[ {"role": "system", "content": _SYSTEM_PROMPT}, {"role": "user", "content": f"Validate these symptoms:\n{symptom_list}"} ], temperature=0.1, max_tokens=800, ) content = response.choices[0].message.content.strip() # Extract JSON from response (handle markdown code blocks) json_match = re.search(r'\{[\s\S]*\}', content) if not json_match: print(f"[symptom_validator] Could not parse JSON from: {content[:200]}") return default data = json.loads(json_match.group()) results = data.get('results', []) valid = [r['reason'] for r in results if r.get('valid')] invalid = [ {'symptom': r['symptom'], 'reason': r.get('reason', 'Not a medical symptom')} for r in results if not r.get('valid') ] boost = float(data.get('boost', 0.0)) boost = max(0.0, min(0.25, boost)) # hard cap summary = data.get('summary', '') print(f"[symptom_validator] Valid: {valid} | Invalid: {[i['symptom'] for i in invalid]} | Boost: {boost:.3f}") return { 'valid': valid, 'invalid': invalid, 'boost': boost, 'summary': summary, 'raw_results': results, } except Exception as e: print(f"[symptom_validator] Error: {e}") return default