# psypredict-backend/app/services/ollama_engine.py
# feat: replace Ollama with Groq API (llama-3.3-70b-versatile) — commit befb434
"""
ollama_engine.py β€” PsyPredict LLM Engine (Groq / Llama3.3-70B)
Replaces Ollama with Groq's API. Same interface β€” no other files need changing.
Features:
- Groq API via httpx (OpenAI-compatible endpoint)
- Structured JSON output via ---JSON--- marker + PsychReport schema
- Streaming support
- Retry with exponential backoff
- Graceful fallback if API key missing or Groq unreachable
"""
from __future__ import annotations
import asyncio
import json
import logging
from typing import AsyncIterator, List, Optional
import httpx
from app.config import get_settings
from app.schemas import (
ConversationMessage,
PsychReport,
RiskLevel,
fallback_report,
)
# Module-level logger; handlers/levels are configured by the application.
logger = logging.getLogger(__name__)
# ---------------------------------------------------------------------------
# Groq API base URL (OpenAI-compatible)
# ---------------------------------------------------------------------------
# All endpoints used below ("/chat/completions", "/models") are resolved
# relative to this base by the httpx client.
GROQ_API_BASE: str = "https://api.groq.com/openai/v1"
# ---------------------------------------------------------------------------
# System Prompt
# ---------------------------------------------------------------------------
# Instructs the model to emit a conversational reply first, then the
# "---JSON---" marker, then a PsychReport JSON block. _parse_response()
# splits the raw completion on that exact marker.
SYSTEM_PROMPT = """You are a compassionate clinical AI therapist integrated into PsyPredict, a mental health platform.
Your role is twofold:
1. Respond as a warm, empathetic therapist β€” never robotic, never dismissive.
2. Provide a structured backend psychological assessment in JSON format.
== CONVERSATIONAL RESPONSE RULES ==
- ALWAYS give a full, thoughtful, empathetic response FIRST (before the JSON block).
- Responses must be at least 3-5 sentences. Never one-liners.
- Validate the user's feelings. Reflect back what they shared. Show you truly listened.
- Do NOT start with "I'm here to help" or generic openers. Be specific to what they said.
- Use warm, humanizing language. Be like a therapist who genuinely cares, not a support chatbot.
- If the situation involves trauma, grief, betrayal, or crisis β€” respond with appropriate gravity and compassion.
- Suggest one concrete, actionable step at the end of your reply.
- Do NOT mention the JSON block, schema, or any technical terms in your reply.
== JSON ASSESSMENT RULES ==
After your conversational response, add the marker: ---JSON---
Then provide the PsychReport JSON.
1. Output ONLY valid JSON conforming exactly to the PsychReport schema below.
2. Do NOT fabricate clinical diagnoses. Infer only from the evidence provided.
3. cognitive_distortions must reference recognized CBT distortion labels only.
4. suggested_interventions must be concrete and clinically actionable.
5. confidence_score reflects YOUR confidence in this assessment (0.0 to 1.0).
6. crisis_triggered MUST be false β€” crisis detection is handled by a separate layer.
7. service_degraded MUST be false.
PSYCH_REPORT_SCHEMA:
{
"risk_classification": "<MINIMAL|LOW|MODERATE|HIGH|CRITICAL>",
"emotional_state_summary": "<string>",
"behavioral_inference": "<string>",
"cognitive_distortions": ["<string>", ...],
"suggested_interventions": ["<string>", ...],
"confidence_score": <float 0.0-1.0>,
"crisis_triggered": false,
"crisis_resources": null,
"service_degraded": false
}
Output format:
<Your full, empathetic therapist response here β€” 3-5 sentences minimum>
---JSON---
{ ...psych report json... }
"""
# ---------------------------------------------------------------------------
# FACE → DISTRESS SCORE mapping
# ---------------------------------------------------------------------------
# Heuristic distress score in [0, 1] per webcam-detected emotion label.
# Unknown labels fall back to 0.20 (see OllamaEngine._build_messages).
FACE_DISTRESS_MAP: dict[str, float] = {
    "fear": 0.80,
    "sad": 0.70,
    "angry": 0.50,
    "disgust": 0.40,
    "surprised": 0.30,
    "neutral": 0.20,
    "happy": 0.05,
}
class OllamaEngine:
    """
    LLM engine backed by the Groq API (Llama 3.3 70B, OpenAI-compatible).

    Named OllamaEngine to preserve all existing imports across the codebase;
    the OLLAMA_* timeout/retry settings are reused for the same reason.
    """

    # HTTP statuses worth retrying: 429 (rate limit) and transient 5xx.
    # Other status errors (bad request, invalid key, ...) abort the retry
    # loop immediately — repeating them cannot succeed.
    _RETRYABLE_STATUS: frozenset = frozenset({429, 500, 502, 503, 504})

    def __init__(self) -> None:
        # Settings are resolved once at construction (get_settings is assumed
        # cached by app.config — TODO confirm).
        self.settings = get_settings()

    def _headers(self) -> dict:
        """Authorization + content-type headers attached to every request."""
        return {
            "Authorization": f"Bearer {self.settings.GROQ_API_KEY}",
            "Content-Type": "application/json",
        }

    def _make_client(self, stream: bool = False) -> httpx.AsyncClient:
        """
        Build a fresh AsyncClient for one request.

        Streaming requests get an unlimited read timeout (tokens may trickle
        in slowly); non-streaming requests use OLLAMA_TIMEOUT_S.
        """
        read_timeout = None if stream else float(self.settings.OLLAMA_TIMEOUT_S)
        return httpx.AsyncClient(
            base_url=GROQ_API_BASE,
            headers=self._headers(),
            timeout=httpx.Timeout(
                connect=10.0,
                read=read_timeout,
                write=30.0,
                pool=5.0,
            ),
        )

    # ------------------------------------------------------------------
    # Health Check
    # ------------------------------------------------------------------
    async def is_reachable(self) -> bool:
        """Return True if the Groq API key is set and the endpoint responds."""
        if not self.settings.GROQ_API_KEY:
            logger.warning("GROQ_API_KEY is not set.")
            return False
        try:
            async with self._make_client() as client:
                # Tight per-request timeout: this is only a liveness probe.
                resp = await client.get("/models", timeout=5.0)
                return resp.status_code == 200
        except Exception:
            # Any transport failure (DNS, TLS, refused) means "unreachable".
            return False

    async def close(self) -> None:
        """No-op: clients are created and closed per request."""

    # ------------------------------------------------------------------
    # Context Window Trimming
    # ------------------------------------------------------------------
    def _trim_history(
        self, history: List[ConversationMessage]
    ) -> List[ConversationMessage]:
        """Keep only the most recent MAX_CONTEXT_TURNS conversation turns."""
        max_turns = self.settings.MAX_CONTEXT_TURNS
        # One "turn" = a user message plus an assistant reply, hence * 2.
        if len(history) <= max_turns * 2:
            return history
        return history[-(max_turns * 2):]

    # ------------------------------------------------------------------
    # Messages Builder (Groq uses chat format, not raw prompt)
    # ------------------------------------------------------------------
    def _build_messages(
        self,
        user_text: str,
        face_emotion: str,
        history: List[ConversationMessage],
        text_emotion_summary: Optional[str] = None,
    ) -> list:
        """
        Build the messages array for Groq's chat completions API.

        The system prompt is a dedicated system message, trimmed history
        becomes alternating user/assistant messages, and the multimodal
        context (face emotion + optional text-emotion summary) is appended
        to the final user message.
        """
        messages = [{"role": "system", "content": SYSTEM_PROMPT}]
        for msg in self._trim_history(history):
            messages.append({
                "role": msg.role.value,
                "content": msg.content,
            })
        # Unknown emotion labels fall back to a mild 0.20 distress score.
        face_distress = FACE_DISTRESS_MAP.get(face_emotion.lower(), 0.20)
        multimodal_ctx = (
            f"\n\n[MULTIMODAL CONTEXT]\n"
            f"Face emotion (webcam): {face_emotion} (distress score: {face_distress:.2f})\n"
        )
        if text_emotion_summary:
            multimodal_ctx += f"Text emotion (DistilBERT): {text_emotion_summary}\n"
        messages.append({"role": "user", "content": user_text + multimodal_ctx})
        return messages

    # ------------------------------------------------------------------
    # Parse LLM Output → (reply_text, PsychReport)
    # ------------------------------------------------------------------
    def _parse_response(self, raw: str) -> tuple[str, PsychReport]:
        """
        Split raw LLM output into (reply_text, PsychReport).

        The model is instructed to emit reply text, then "---JSON---", then
        the report JSON. If the marker is missing or the JSON is invalid,
        degrade gracefully: return a fallback report, and use the full raw
        text as the reply so the user still sees something.
        """
        marker = "---JSON---"
        if marker in raw:
            reply_text, _, json_block = raw.partition(marker)
            reply_text = reply_text.strip()
            json_block = json_block.strip()
        else:
            reply_text = ""
            json_block = raw.strip()
        # Strip markdown code fences the model sometimes wraps JSON in.
        if json_block.startswith("```"):
            json_block = "\n".join(
                line for line in json_block.split("\n")
                if not line.startswith("```")
            ).strip()
        try:
            data = json.loads(json_block)
            report = PsychReport(**data)
        except (json.JSONDecodeError, TypeError, ValueError, KeyError) as exc:
            # TypeError covers non-dict JSON (e.g. a bare list) passed to
            # PsychReport(**data); pydantic's ValidationError is a ValueError
            # subclass, so it is covered too.
            logger.warning(
                "Failed to parse PsychReport from Groq output: %s | raw=%r",
                exc,
                json_block[:500],
            )
            report = fallback_report()
        if not reply_text:
            reply_text = raw.strip()
        return reply_text, report

    # ------------------------------------------------------------------
    # Generate (non-streaming)
    # ------------------------------------------------------------------
    async def generate(
        self,
        user_text: str,
        face_emotion: str = "neutral",
        history: Optional[List[ConversationMessage]] = None,
        text_emotion_summary: Optional[str] = None,
    ) -> tuple[str, PsychReport]:
        """
        Run one non-streaming chat completion and parse the result.

        Never raises: on missing API key or unrecoverable transport/HTTP
        failure it returns a human-readable message plus fallback_report().
        Transient failures (timeouts, 429, 5xx) are retried with exponential
        backoff up to OLLAMA_RETRIES attempts.
        """
        if not self.settings.GROQ_API_KEY:
            logger.warning("GROQ_API_KEY not set — returning fallback.")
            return ("Groq API key is not configured.", fallback_report())
        if history is None:
            history = []
        messages = self._build_messages(user_text, face_emotion, history, text_emotion_summary)
        payload = {
            "model": self.settings.GROQ_MODEL,
            "messages": messages,
            "temperature": 0.2,  # low temperature: consistent clinical tone
            "max_tokens": 1024,
            "stream": False,
        }
        last_error: Optional[Exception] = None
        delay = self.settings.OLLAMA_RETRY_DELAY_S
        for attempt in range(1, self.settings.OLLAMA_RETRIES + 1):
            try:
                logger.info("Groq generate attempt %d/%d", attempt, self.settings.OLLAMA_RETRIES)
                async with self._make_client() as client:
                    resp = await client.post("/chat/completions", json=payload)
                    resp.raise_for_status()
                    data = resp.json()
                    raw_text: str = data["choices"][0]["message"]["content"]
                    return self._parse_response(raw_text)
            except httpx.TimeoutException as exc:
                last_error = exc
                logger.warning("Groq timeout on attempt %d: %s", attempt, exc)
            except httpx.HTTPStatusError as exc:
                last_error = exc
                logger.error("Groq HTTP error %s: %s", exc.response.status_code, exc.response.text)
                # Only rate limits and 5xx are transient; other 4xx errors
                # (bad payload, bad key) will never succeed on retry.
                if exc.response.status_code not in self._RETRYABLE_STATUS:
                    break
            except Exception as exc:
                last_error = exc
                logger.error("Groq unexpected error: %s", exc)
            if attempt < self.settings.OLLAMA_RETRIES:
                await asyncio.sleep(delay)
                delay *= 2  # exponential backoff
        logger.error("All Groq attempts failed. Last error: %s", last_error)
        return ("The inference service is temporarily unavailable. Please try again shortly.", fallback_report())

    # ------------------------------------------------------------------
    # Generate (streaming)
    # ------------------------------------------------------------------
    async def generate_stream(
        self,
        user_text: str,
        face_emotion: str = "neutral",
        history: Optional[List[ConversationMessage]] = None,
        text_emotion_summary: Optional[str] = None,
    ) -> AsyncIterator[str]:
        """
        Yield raw tokens from Groq as they arrive (SSE parsing, no retry).

        The stream includes the ---JSON--- marker and trailing report JSON
        exactly as the model emits them; callers split reply from report.
        On failure, a single error string is yielded instead of raising.
        """
        if not self.settings.GROQ_API_KEY:
            logger.warning("GROQ_API_KEY not set — returning fallback stream.")
            yield "Groq API key is not configured.\n---JSON---\n" + json.dumps(fallback_report().model_dump())
            return
        if history is None:
            history = []
        messages = self._build_messages(user_text, face_emotion, history, text_emotion_summary)
        payload = {
            "model": self.settings.GROQ_MODEL,
            "messages": messages,
            "temperature": 0.2,
            "max_tokens": 1024,
            "stream": True,
        }
        try:
            async with self._make_client(stream=True) as client:
                async with client.stream("POST", "/chat/completions", json=payload) as resp:
                    resp.raise_for_status()
                    # Server-Sent Events: each payload line is "data: {json}"
                    # and the stream terminates with "data: [DONE]".
                    async for line in resp.aiter_lines():
                        if not line.startswith("data: "):
                            continue
                        data_str = line[len("data: "):]
                        if data_str.strip() == "[DONE]":
                            break
                        try:
                            chunk = json.loads(data_str)
                            token = chunk["choices"][0].get("delta", {}).get("content", "")
                            if token:
                                yield token
                        except (json.JSONDecodeError, KeyError, IndexError):
                            # Skip malformed or empty-choices chunks.
                            continue
        except Exception as exc:
            logger.error("Groq streaming failed: %s", exc)
            yield "\n[Inference error — Groq request failed. Try again.]\n"
# ---------------------------------------------------------------------------
# Singleton — same name so all existing imports work unchanged
# ---------------------------------------------------------------------------
# Constructed at import time; __init__ only resolves settings, so this does
# no network I/O on import.
ollama_engine: OllamaEngine = OllamaEngine()