Spaces:
Running
Running
File size: 10,103 Bytes
"""
src/api/schemas.py - Pydantic request/response models for MediRAG FastAPI
=========================================================================
FR-18: Input validation limits from config.yaml (``api:`` section):
- max_query_length: 500
- max_answer_length: 2000
- max_chunks: 10
- max_chunk_length: 2000
These limits apply to POST /evaluate; POST /query accepts questions of up to
8000 characters because they may include document context.
"""
from __future__ import annotations
from typing import Any, Dict, List, Optional
from pydantic import BaseModel, Field, field_validator
class IngestRequest(BaseModel):
    """Request body for POST /ingest.

    Carries one custom document to append to the FAISS index. ``title`` and
    ``text`` are required (``text`` must be at least 10 characters);
    ``pub_type`` and ``source`` fall back to "clinical_guideline" and
    "custom_upload" respectively.
    """

    # Required content
    title: str = Field(description="Document title")
    text: str = Field(min_length=10, description="Raw text of the document to ingest")

    # Optional metadata with defaults
    pub_type: str = Field("clinical_guideline", description="Document type")
    source: str = Field("custom_upload", description="Source of the document")
# ---------------------------------------------------------------------------
# Request schemas
# ---------------------------------------------------------------------------
class ContextChunk(BaseModel):
    """One retrieved context chunk handed to the evaluation pipeline.

    Only ``text`` is required (1 to 2000 characters). Every other field is
    optional metadata that is passed through to the pipeline modules.
    """

    text: str = Field(
        min_length=1,
        max_length=2000,
        description="Chunk text (max 2000 chars)",
    )

    # --- optional pass-through metadata ----------------------------------
    chunk_id: Optional[str] = Field(default=None)
    pub_type: Optional[str] = Field(default=None)
    pub_year: Optional[int] = Field(default=None)
    source: Optional[str] = Field(default=None)
    title: Optional[str] = Field(default=None)
    # Evidence tier, when the caller has already labelled one.
    tier_type: Optional[str] = Field(default=None)
    # Similarity score reported by the retriever.
    score: Optional[float] = Field(default=None)
class EvaluateRequest(BaseModel):
    """POST /evaluate - request body.

    Bundles the question, the LLM answer to score, and the retrieved context
    chunks the answer should be grounded in, plus optional per-request LLM
    overrides. Length limits follow FR-18.
    """

    question: str = Field(
        ...,
        min_length=5,
        max_length=500,
        description="User question (5-500 chars)",
        examples=["What is the recommended dosage of Metformin for Type 2 Diabetes in elderly patients?"],
    )
    answer: str = Field(
        ...,
        min_length=1,
        max_length=2000,
        description="LLM-generated answer to evaluate (1-2000 chars)",
        examples=["Metformin is typically started at 500 mg twice daily with meals..."],
    )
    context_chunks: List[ContextChunk] = Field(
        ...,
        min_length=1,
        max_length=10,
        description="Retrieved context chunks (1-10 items)",
    )
    run_ragas: bool = Field(
        default=False,
        description="Run RAGAS evaluation (requires Ollama or OpenAI backend; slower)",
    )
    llm_provider: Optional[str] = Field(
        default=None,
        description="LLM provider override: 'gemini' or 'ollama'",
    )
    # repr=False keeps the caller's secret out of repr()/str() of the model, so
    # it does not leak into logs or tracebacks that print the request object.
    llm_api_key: Optional[str] = Field(
        default=None,
        repr=False,
        description="API Key if accessing Gemini",
    )
    llm_model: Optional[str] = Field(
        default=None,
        description="Specific model string if overriding defaults",
    )
    # NOTE(review): this path comes straight from the request body. If the
    # server opens it, a caller can point it at an arbitrary file; consider
    # restricting it to a known data directory where it is consumed.
    rxnorm_cache_path: str = Field(
        default="data/rxnorm_cache.csv",
        description="Path to RxNorm cache CSV",
    )

    @field_validator("context_chunks")
    @classmethod
    def at_least_one_chunk(cls, v: List[ContextChunk]) -> List[ContextChunk]:
        """Reject an empty ``context_chunks`` list.

        ``min_length=1`` on the field already enforces this; the check is kept
        as a defensive guard that carries an explicit error message.

        Raises:
            ValueError: if ``v`` is empty.
        """
        if not v:
            raise ValueError("At least one context chunk is required")
        return v
# ---------------------------------------------------------------------------
# Response schemas
# ---------------------------------------------------------------------------
class ModuleScore(BaseModel):
    """Outcome of a single evaluation module.

    ``score`` is constrained to [0, 1]. ``details`` holds module-specific
    extras and defaults to an empty dict. ``error`` carries a message when
    the module failed.
    """

    score: float = Field(
        ge=0.0,
        le=1.0,
        description="Module score in [0, 1]",
    )
    details: Dict[str, Any] = Field(default_factory=dict)
    error: Optional[str] = Field(
        default=None,
        description="Error message if module failed",
    )
    latency_ms: Optional[int] = Field(default=None)
class ModuleResults(BaseModel):
    """Per-module scores for one evaluation run.

    Each slot defaults to ``None`` when no ``ModuleScore`` is supplied for it.
    """

    faithfulness: Optional[ModuleScore] = Field(default=None)
    entity_verifier: Optional[ModuleScore] = Field(default=None)
    source_credibility: Optional[ModuleScore] = Field(default=None)
    contradiction: Optional[ModuleScore] = Field(default=None)
    ragas: Optional[ModuleScore] = Field(default=None)
class EvaluateResponse(BaseModel):
    """POST /evaluate - response body (FR-17 format).

    ``composite_score`` is the weighted score in [0, 1]. ``hrs`` is its
    inverted 0-100 risk counterpart. Per-module results are in
    ``module_results``.
    """

    composite_score: float = Field(
        ...,
        ge=0.0,
        le=1.0,
        description="Weighted composite score in [0, 1]",
    )
    hrs: int = Field(
        ...,
        ge=0,
        le=100,
        description="Health Risk Score = round(100 * (1 - composite_score))",
    )
    confidence_level: str = Field(
        ...,
        description="HIGH / MODERATE / LOW",
    )
    risk_band: str = Field(
        ...,
        description="LOW / MODERATE / HIGH / CRITICAL",
    )
    module_results: ModuleResults
    total_pipeline_ms: int = Field(..., description="Total wall-clock time in ms")
class ChatMessage(BaseModel):
    """A single chat message: the sender's role and the message text.

    Both fields are required free-form strings; this model enforces no
    particular set of role values.
    """

    role: str = Field(...)
    content: str = Field(...)
class ChatRequest(BaseModel):
    """Chat request: a message history plus optional prompt and persona overrides."""

    messages: List[ChatMessage] = Field(...)
    system_prompt: Optional[str] = Field(default=None)
    # Defaults to the physician persona, the same default QueryRequest uses.
    persona: Optional[str] = Field(default="physician")
class HealthResponse(BaseModel):
    """Response body for GET /health: liveness plus dependency status.

    Only ``ollama_available`` must be supplied. ``status`` and ``version``
    default to "ok" and "0.1.0".
    """

    status: str = Field("ok")
    ollama_available: bool = Field(...)
    version: str = Field("0.1.0")
# ---------------------------------------------------------------------------
# End-to-end query schemas (POST /query)
# ---------------------------------------------------------------------------
class QueryRequest(BaseModel):
    """POST /query - only a question is needed; retrieval and generation happen server-side.

    Optional fields let the caller override the LLM backend per request and
    toggle extra features: RAGAS, consensus, privacy shield, and demo
    hallucination injection.
    """

    question: str = Field(
        ...,
        min_length=5,
        max_length=8000,
        description="Medical question (5-8000 chars; may include doc context)",
        examples=["What is the recommended dosage of Metformin for elderly Type 2 Diabetes patients?"],
    )
    top_k: int = Field(
        default=5,
        ge=1,
        le=10,
        description="Number of context chunks to retrieve (1-10)",
    )
    run_ragas: bool = Field(
        default=False,
        description="Run RAGAS evaluation (requires LLM backend)",
    )

    # Per-request LLM overrides. When unset, the server's config.yaml values
    # are used, which lets callers bring their own key and model.
    llm_provider: Optional[str] = Field(
        default=None,
        description="LLM provider override: 'gemini' or 'ollama'",
    )
    # repr=False keeps the key out of repr()/str() of the request. This backs
    # the "Not logged or stored" promise if the model object itself is logged.
    llm_api_key: Optional[str] = Field(
        default=None,
        repr=False,
        description="API key override (e.g. Gemini key). Not logged or stored.",
    )
    llm_model: Optional[str] = Field(
        default=None,
        description="Model name override (e.g. 'gemini-2.5-flash-lite')",
    )
    # NOTE(review): a caller-supplied base URL that the server connects to is a
    # potential SSRF vector; consider allow-listing hosts where it is consumed.
    ollama_url: Optional[str] = Field(
        default=None,
        description="Ollama base URL override (e.g. 'http://localhost:11434')",
    )

    # Demo/test only: injects a false claim into the LLM answer before
    # evaluation to demonstrate the intervention system catching hallucinations.
    # NOTE(review): consider rejecting this field outside demo deployments.
    inject_hallucination: Optional[str] = Field(
        default=None,
        description="[DEMO ONLY] Appends a false medical claim to the answer before evaluation.",
    )

    # Consensus Engine (Option 2)
    use_consensus: bool = Field(
        default=False,
        description="Run multiple models and compare for clinical agreement.",
    )

    # Privacy Shield (Option 1)
    use_privacy_shield: bool = Field(
        default=False,
        description="Automatically redact PHI/PII (names, IDs) before external API calls.",
    )

    system_prompt: Optional[str] = Field(
        default=None,
        description="Custom system prompt to override the default clinical persona.",
    )
    persona: Optional[str] = Field(
        default="physician",
        description="The target audience for the response: 'physician' or 'patient'.",
    )
class RetrievedChunk(BaseModel):
    """Chunk returned with a query response so callers can inspect the evidence.

    ``text`` is the only required field; every metadata field defaults to
    ``None``.
    """

    chunk_id: Optional[str] = Field(default=None)
    text: str = Field(...)
    source: Optional[str] = Field(default=None)
    pub_type: Optional[str] = Field(default=None)
    pub_year: Optional[int] = Field(default=None)
    title: Optional[str] = Field(default=None)
    similarity_score: Optional[float] = Field(default=None)
class QueryResponse(BaseModel):
    """POST /query - full end-to-end response.

    Combines the generated answer and the retrieved chunks with the same
    evaluation fields as EvaluateResponse. These are followed by the optional
    safety-intervention, consensus, privacy-shield and coverage-gap fields.
    """

    question: str
    generated_answer: str
    retrieved_chunks: List[RetrievedChunk]

    # Evaluation fields (same as EvaluateResponse)
    composite_score: float = Field(
        ...,
        ge=0.0,
        le=1.0,
        description="Weighted composite score in [0, 1]",
    )
    hrs: int = Field(
        ...,
        ge=0,
        le=100,
        description="Health Risk Score = round(100 * (1 - composite_score))",
    )
    confidence_level: str
    risk_band: str
    module_results: ModuleResults
    total_pipeline_ms: int

    # Intervention fields (active safety gate)
    intervention_applied: bool = Field(
        default=False,
        description="True if the system modified or blocked the response for safety.",
    )
    intervention_reason: Optional[str] = Field(
        default=None,
        description="CRITICAL_BLOCKED | HIGH_RISK_REGENERATED | null",
    )
    original_answer: Optional[str] = Field(
        default=None,
        description="The original (unsafe) LLM answer before intervention, for transparency.",
    )
    intervention_details: Optional[Dict[str, Any]] = Field(
        default=None,
        description="Which modules triggered the intervention and their scores.",
    )

    # Consensus fields
    consensus_results: Optional[Dict[str, Any]] = Field(
        default=None,
        description="Results from the multi-model agreement check.",
    )

    # Privacy Shield fields
    privacy_applied: bool = Field(default=False)
    privacy_details: Optional[Dict[str, Any]] = Field(default=None)

    # Coverage gap gate: distinguishes missing DB coverage from hallucination
    coverage_gap: bool = Field(
        default=False,
        description="True when retrieval quality is low - the database may lack coverage for this topic.",
    )
    coverage_gap_details: Optional[Dict[str, Any]] = Field(
        default=None,
        description="gap_type (COVERAGE_GAP | HALLUCINATION), retrieval_confidence, threshold.",
    )
|