"""Judge handler for evidence assessment using PydanticAI.""" import asyncio import json from typing import Any, ClassVar import structlog from huggingface_hub import InferenceClient from pydantic_ai import Agent from pydantic_ai.models.anthropic import AnthropicModel from pydantic_ai.models.huggingface import HuggingFaceModel from pydantic_ai.models.openai import OpenAIChatModel from pydantic_ai.providers.anthropic import AnthropicProvider from pydantic_ai.providers.huggingface import HuggingFaceProvider from pydantic_ai.providers.openai import OpenAIProvider from tenacity import retry, retry_if_exception_type, stop_after_attempt, wait_exponential from src.config.domain import ResearchDomain from src.prompts.judge import ( format_empty_evidence_prompt, format_user_prompt, get_system_prompt, select_evidence_for_judge, ) from src.utils.config import settings from src.utils.models import AssessmentDetails, Evidence, JudgeAssessment logger = structlog.get_logger() def _extract_titles_from_evidence( evidence: list[Evidence], max_items: int = 5, fallback_message: str | None = None ) -> list[str]: """Extract truncated titles from evidence for fallback display. Args: evidence: List of evidence items max_items: Maximum number of titles to extract fallback_message: Message to return if no evidence provided Returns: List of truncated titles (max 150 chars each) """ findings = [] for e in evidence[:max_items]: title = e.citation.title if len(title) > 150: title = title[:147] + "..." findings.append(title) if not findings and fallback_message: return [fallback_message] return findings def get_model() -> Any: """Get the LLM model based on configuration. Explicitly passes API keys from settings to avoid requiring users to export environment variables manually. """ llm_provider = settings.llm_provider if llm_provider == "anthropic": provider = AnthropicProvider(api_key=settings.anthropic_api_key) return AnthropicModel(settings.anthropic_model, provider=provider) if llm_provider == "huggingface": # Free tier - uses HF_TOKEN from environment if available model_name = settings.huggingface_model or "meta-llama/Llama-3.1-70B-Instruct" hf_provider = HuggingFaceProvider(api_key=settings.hf_token) return HuggingFaceModel(model_name, provider=hf_provider) if llm_provider != "openai": logger.warning("Unknown LLM provider, defaulting to OpenAI", provider=llm_provider) openai_provider = OpenAIProvider(api_key=settings.openai_api_key) return OpenAIChatModel(settings.openai_model, provider=openai_provider) class JudgeHandler: """ Handles evidence assessment using an LLM with structured output. Uses PydanticAI to ensure responses match the JudgeAssessment schema. """ def __init__( self, model: Any = None, domain: ResearchDomain | str | None = None, ) -> None: """ Initialize the JudgeHandler. Args: model: Optional PydanticAI model. If None, uses config default. domain: Research domain for prompt customization. """ self.model = model or get_model() self.domain = domain self.agent = Agent( model=self.model, output_type=JudgeAssessment, system_prompt=get_system_prompt(domain), retries=3, ) async def assess( self, question: str, evidence: list[Evidence], iteration: int = 0, max_iterations: int = 10, ) -> JudgeAssessment: """ Assess evidence and determine if it's sufficient. Args: question: The user's research question evidence: List of Evidence objects from search iteration: Current iteration number max_iterations: Maximum allowed iterations Returns: JudgeAssessment with evaluation results Raises: JudgeError: If assessment fails after retries """ logger.info( "Starting evidence assessment", question=question[:100], evidence_count=len(evidence), iteration=iteration, domain=self.domain, ) # Format the prompt based on whether we have evidence if evidence: # Select diverse evidence using embeddings (if available) selected_evidence = await select_evidence_for_judge(evidence, question) user_prompt = format_user_prompt( question, selected_evidence, iteration, max_iterations, total_evidence_count=len(evidence), domain=self.domain, ) else: user_prompt = format_empty_evidence_prompt(question) try: # Run the agent with structured output result = await self.agent.run(user_prompt) assessment = result.output logger.info( "Assessment complete", sufficient=assessment.sufficient, recommendation=assessment.recommendation, confidence=assessment.confidence, ) return assessment except Exception as e: logger.error("Assessment failed", error=str(e)) # Return a safe default assessment on failure return self._create_fallback_assessment(question, str(e)) def _create_fallback_assessment( self, question: str, error: str, ) -> JudgeAssessment: """ Create a fallback assessment when LLM fails. Args: question: The original question error: The error message Returns: Safe fallback JudgeAssessment """ return JudgeAssessment( details=AssessmentDetails( mechanism_score=0, mechanism_reasoning="Assessment failed due to LLM error", clinical_evidence_score=0, clinical_reasoning="Assessment failed due to LLM error", drug_candidates=[], key_findings=[], ), sufficient=False, confidence=0.0, recommendation="continue", next_search_queries=[ f"{question} mechanism", f"{question} clinical trials", f"{question} drug candidates", ], reasoning=f"Assessment failed: {error}. Recommend retrying with refined queries.", ) class HFInferenceJudgeHandler: """ JudgeHandler using HuggingFace Inference API for FREE LLM calls. Defaults to Llama-3.1-8B-Instruct (requires HF_TOKEN) or falls back to public models. """ FALLBACK_MODELS: ClassVar[list[str]] = [ "meta-llama/Llama-3.1-8B-Instruct", # Primary (Gated) "mistralai/Mistral-7B-Instruct-v0.3", # Secondary "HuggingFaceH4/zephyr-7b-beta", # Fallback (Ungated) ] # Force synthesis after N consecutive failures to prevent infinite loops # Rationale: 3 models x 3 retries each = 9 total API attempts before circuit break MAX_CONSECUTIVE_FAILURES: ClassVar[int] = 3 def __init__( self, model_id: str | None = None, domain: ResearchDomain | str | None = None, ) -> None: """ Initialize with HF Inference client. Args: model_id: Optional specific model ID. If None, uses FALLBACK_MODELS chain. domain: Research domain for prompt customization. """ self.model_id = model_id self.domain = domain # Will automatically use HF_TOKEN from env if available self.client = InferenceClient() self.call_count = 0 self.consecutive_failures = 0 # Track failures to prevent infinite loops self.last_question: str | None = None self.last_evidence: list[Evidence] | None = None async def assess( self, question: str, evidence: list[Evidence], iteration: int = 0, max_iterations: int = 10, ) -> JudgeAssessment: """ Assess evidence using HuggingFace Inference API. Attempts models in order until one succeeds. After MAX_CONSECUTIVE_FAILURES, forces synthesis to prevent infinite loops. """ self.call_count += 1 # Session-based reset: new question = new research session = reset failures # Prevents failure state from leaking across different user queries if question != self.last_question and self.last_question is not None: self.consecutive_failures = 0 self.last_question = question self.last_evidence = evidence # BUG FIX: After N consecutive failures, force synthesis to break infinite loop if self.consecutive_failures >= self.MAX_CONSECUTIVE_FAILURES: logger.warning( "Max consecutive failures reached, forcing synthesis", failures=self.consecutive_failures, evidence_count=len(evidence), ) return self._create_forced_synthesis_assessment(question, evidence) # Format the user prompt if evidence: selected_evidence = await select_evidence_for_judge(evidence, question) user_prompt = format_user_prompt( question, selected_evidence, iteration, max_iterations, total_evidence_count=len(evidence), domain=self.domain, ) else: user_prompt = format_empty_evidence_prompt(question) models_to_try: list[str] = [self.model_id] if self.model_id else self.FALLBACK_MODELS last_error: Exception | None = None for model in models_to_try: try: result = await self._call_with_retry(model, user_prompt, question) self.consecutive_failures = 0 # Reset on success return result except Exception as e: # Check for 402/Quota errors to fail fast error_str = str(e) if ( "402" in error_str or "quota" in error_str.lower() or "payment required" in error_str.lower() ): logger.error("HF Quota Exhausted", error=error_str) return self._create_quota_exhausted_assessment(question, evidence) logger.warning("Model failed", model=model, error=str(e)) last_error = e continue # All models failed - increment failure counter self.consecutive_failures += 1 logger.error( "All HF models failed", error=str(last_error), consecutive_failures=self.consecutive_failures, ) return self._create_fallback_assessment(question, str(last_error)) @retry( stop=stop_after_attempt(3), wait=wait_exponential(multiplier=1, min=1, max=4), retry=retry_if_exception_type(Exception), reraise=True, ) async def _call_with_retry(self, model: str, prompt: str, question: str) -> JudgeAssessment: """Make API call with retry logic using chat_completion.""" loop = asyncio.get_running_loop() system_prompt = get_system_prompt(self.domain) # Build messages for chat_completion (model-agnostic) messages = [ { "role": "system", "content": f"""{system_prompt} IMPORTANT: Respond with ONLY valid JSON matching this schema: {{ "details": {{ "mechanism_score": , "mechanism_reasoning": "", "clinical_evidence_score": , "clinical_reasoning": "", "drug_candidates": ["", ...], "key_findings": ["", ...] }}, "sufficient": , "confidence": , "recommendation": "continue" | "synthesize", "next_search_queries": ["", ...], "reasoning": "" }}""", }, {"role": "user", "content": prompt}, ] # Use chat_completion (conversational task - supported by all models) response = await loop.run_in_executor( None, lambda: self.client.chat_completion( messages=messages, model=model, max_tokens=1024, temperature=0.1, ), ) # Extract content from response content = response.choices[0].message.content if not content: raise ValueError("Empty response from model") # Extract and parse JSON json_data = self._extract_json(content) if not json_data: raise ValueError("No valid JSON found in response") return JudgeAssessment(**json_data) def _extract_json(self, text: str) -> dict[str, Any] | None: """ Robust JSON extraction that handles markdown blocks and nested braces. """ text = text.strip() # Remove markdown code blocks if present (with bounds checking) if "```json" in text: parts = text.split("```json", 1) if len(parts) > 1: inner_parts = parts[1].split("```", 1) text = inner_parts[0] elif "```" in text: parts = text.split("```", 1) if len(parts) > 1: inner_parts = parts[1].split("```", 1) text = inner_parts[0] text = text.strip() # Find first '{' start_idx = text.find("{") if start_idx == -1: return None # Stack-based parsing ignoring chars in strings count = 0 in_string = False escape = False for i, char in enumerate(text[start_idx:], start=start_idx): if in_string: if escape: escape = False elif char == "\\": escape = True elif char == '"': in_string = False elif char == '"': in_string = True elif char == "{": count += 1 elif char == "}": count -= 1 if count == 0: try: result = json.loads(text[start_idx : i + 1]) if isinstance(result, dict): return result return None except json.JSONDecodeError: return None return None def _create_quota_exhausted_assessment( self, question: str, evidence: list[Evidence], ) -> JudgeAssessment: """Create an assessment that stops the loop when quota is exhausted.""" findings = _extract_titles_from_evidence( evidence, max_items=5, fallback_message="No findings available (Quota exceeded and no search results).", ) return JudgeAssessment( details=AssessmentDetails( mechanism_score=0, mechanism_reasoning="Free tier quota exhausted. Unable to analyze mechanism.", clinical_evidence_score=0, clinical_reasoning=( "Free tier quota exhausted. Unable to analyze clinical evidence." ), drug_candidates=["Upgrade to paid API for drug extraction."], key_findings=findings, ), sufficient=True, # STOP THE LOOP confidence=0.0, recommendation="synthesize", next_search_queries=[], reasoning=( "⚠️ **Free Tier Quota Exceeded** ⚠️\n\n" "The HuggingFace Inference API free tier limit has been reached. " "The search results listed below were retrieved but could not be " "analyzed by the AI. " "Please try again later, or add an OpenAI/Anthropic API key above " "for unlimited access." ), ) def _create_forced_synthesis_assessment( self, question: str, evidence: list[Evidence], ) -> JudgeAssessment: """Force synthesis after repeated failures to prevent infinite loops.""" findings = _extract_titles_from_evidence( evidence, max_items=5, fallback_message="No findings available (API failures prevented analysis).", ) return JudgeAssessment( details=AssessmentDetails( mechanism_score=0, mechanism_reasoning="AI analysis unavailable after repeated API failures.", clinical_evidence_score=0, clinical_reasoning="AI analysis unavailable after repeated API failures.", drug_candidates=["AI analysis required for drug identification."], key_findings=findings, ), sufficient=True, # FORCE STOP confidence=0.1, recommendation="synthesize", next_search_queries=[], reasoning=( f"⚠️ **HF Inference Unavailable** ⚠️\n\n" f"The free tier AI service failed {self.MAX_CONSECUTIVE_FAILURES} times. " f"Search found {len(evidence)} sources (listed below) but they could not " "be analyzed by AI.\n\n" "**Options:**\n" "- Add an OpenAI or Anthropic API key for reliable analysis\n" "- Try again later when HF Inference is available\n" "- Review the raw search results below" ), ) def _create_fallback_assessment( self, question: str, error: str, ) -> JudgeAssessment: """Create a fallback assessment when inference fails.""" return JudgeAssessment( details=AssessmentDetails( mechanism_score=0, mechanism_reasoning=f"Assessment failed: {error}", clinical_evidence_score=0, clinical_reasoning=f"Assessment failed: {error}", drug_candidates=[], key_findings=[], ), sufficient=False, confidence=0.0, recommendation="continue", next_search_queries=[ f"{question} mechanism", f"{question} clinical trials", f"{question} drug candidates", ], reasoning=f"HF Inference failed: {error}. Recommend configuring OpenAI/Anthropic key.", ) class MockJudgeHandler: """ Mock JudgeHandler for demo mode without LLM calls. Extracts meaningful information from real search results to provide a useful demo experience without requiring API keys. """ def __init__( self, mock_response: JudgeAssessment | None = None, domain: ResearchDomain | str | None = None, ) -> None: """ Initialize with optional mock response. Args: mock_response: The assessment to return. If None, extracts from evidence. domain: Research domain (ignored in mock but kept for interface compatibility). """ self.mock_response = mock_response self.domain = domain self.call_count = 0 self.last_question: str | None = None self.last_evidence: list[Evidence] | None = None def _extract_key_findings(self, evidence: list[Evidence], max_findings: int = 5) -> list[str]: """Extract key findings from evidence titles.""" # Helper guarantees non-empty list when fallback_message is provided return _extract_titles_from_evidence( evidence, max_items=max_findings, fallback_message="No specific findings extracted (demo mode)", ) def _extract_drug_candidates(self, question: str, evidence: list[Evidence]) -> list[str]: """Extract drug candidates - demo mode returns honest message.""" # Don't attempt heuristic extraction - it produces garbage like "Oral", "Kidney" # Real drug extraction requires LLM analysis return [ "Drug identification requires AI analysis", "Enter API key above for full results", ] async def assess( self, question: str, evidence: list[Evidence], iteration: int = 0, max_iterations: int = 10, ) -> JudgeAssessment: """Return assessment based on actual evidence (demo mode).""" self.call_count += 1 self.last_question = question self.last_evidence = evidence if self.mock_response: return self.mock_response min_evidence = 3 evidence_count = len(evidence) # Extract meaningful data from actual evidence drug_candidates = self._extract_drug_candidates(question, evidence) key_findings = self._extract_key_findings(evidence) # Calculate scores based on evidence quantity mechanism_score = min(10, evidence_count * 2) if evidence_count > 0 else 0 clinical_score = min(10, evidence_count) if evidence_count > 0 else 0 return JudgeAssessment( details=AssessmentDetails( mechanism_score=mechanism_score, mechanism_reasoning=( f"Demo mode: Found {evidence_count} sources. " "Configure LLM API key for detailed mechanism analysis." ), clinical_evidence_score=clinical_score, clinical_reasoning=( f"Demo mode: {evidence_count} sources retrieved from PubMed, " "ClinicalTrials.gov, and Europe PMC. Full analysis requires LLM API key." ), drug_candidates=drug_candidates, key_findings=key_findings, ), sufficient=evidence_count >= min_evidence, confidence=min(0.5, evidence_count * 0.1) if evidence_count > 0 else 0.0, recommendation="synthesize" if evidence_count >= min_evidence else "continue", next_search_queries=( [f"{question} mechanism", f"{question} clinical trials"] if evidence_count < min_evidence else [] ), reasoning=( f"Demo mode assessment based on {evidence_count} real search results. " "For AI-powered analysis with drug candidate identification and " "evidence synthesis, configure OPENAI_API_KEY or ANTHROPIC_API_KEY." ), )