"""
Pydantic schemas for prediction endpoints.
"""
from typing import Dict, List, Literal, Optional

from pydantic import BaseModel, ConfigDict, Field, model_validator
# Model display info schema (for frontend)
class ModelDisplayInfo(BaseModel):
    """Human-readable presentation metadata for a single model.

    Sent to the frontend so it can label a model and explain, in plain
    language, what the model and its explainability method do. Every field
    is required.
    """

    display_name: str = Field(
        default=...,
        description="User-friendly model name (e.g., 'Texture Analysis')",
    )
    short_name: str = Field(
        default=...,
        description="Short identifier (e.g., 'CNN')",
    )
    method_name: str = Field(
        default=...,
        description="Explainability method name (e.g., 'Grad-CAM')",
    )
    method_description: str = Field(
        default=...,
        description="Brief description of the method",
    )
    educational_text: str = Field(
        default=...,
        description="Educational text about what this model analyzes",
    )
    what_it_looks_for: List[str] = Field(
        default=...,
        description="List of things this model looks for",
    )
# LLM single-model explanation schema (for on-demand requests)
class SingleModelInsight(BaseModel):
    """Structured, LLM-written explanation of one model's output.

    Generated on demand for a single model. All four fields are required.
    """

    key_finding: str = Field(
        default=...,
        description="Main finding from the model",
    )
    what_model_saw: str = Field(
        default=...,
        description="What the model detected in the image",
    )
    important_regions: List[str] = Field(
        default=...,
        description="Key regions identified",
    )
    confidence_qualifier: str = Field(
        default=...,
        description="Confidence assessment with hedging",
    )
class PredictionResult(BaseModel):
    """Single prediction result from a model.

    ``pred`` and ``pred_int`` are two encodings of the same label
    (``"real"`` <-> ``0``, ``"fake"`` <-> ``1``). A model validator rejects
    instances in which they disagree, so consumers can read either one.
    The explainability fields (``heatmap_base64``, ``explainability_type``,
    ``focus_summary``) and ``contribution_percentage`` are optional and
    default to ``None``.
    """

    pred: Literal["real", "fake"] = Field(
        ...,
        description="Human-readable prediction label"
    )
    pred_int: Literal[0, 1] = Field(
        ...,
        description="Integer prediction: 0=real, 1=fake"
    )
    prob_fake: float = Field(
        ...,
        ge=0.0,
        le=1.0,
        description="Probability that the image is fake (0.0-1.0)"
    )
    heatmap_base64: Optional[str] = Field(
        None,
        description="Base64-encoded PNG heatmap showing model attention/saliency (when explain=true)"
    )
    explainability_type: Optional[Literal["grad_cam", "attention_rollout"]] = Field(
        None,
        description="Type of explainability method used"
    )
    focus_summary: Optional[str] = Field(
        None,
        description="Brief description of where the model focused (e.g., 'concentrated on face region')"
    )
    contribution_percentage: Optional[float] = Field(
        None,
        ge=0.0,
        le=100.0,
        description="How much this model contributed to the fusion decision (0-100%)"
    )

    @model_validator(mode="after")
    def check_pred_matches_pred_int(self) -> "PredictionResult":
        """Ensure the string and integer labels describe the same class.

        Raises:
            ValueError: if ``pred_int`` is not ``1`` for ``"fake"`` or ``0``
                for ``"real"``. Pydantic surfaces this as a
                ``ValidationError``.
        """
        # Both fields are independently required, so without this check a
        # caller could build e.g. pred="real", pred_int=1 and API clients
        # reading different fields would see contradictory answers.
        expected_int = 1 if self.pred == "fake" else 0
        if self.pred_int != expected_int:
            raise ValueError(
                f"pred={self.pred!r} is inconsistent with pred_int={self.pred_int} "
                "(expected 0=real, 1=fake)"
            )
        return self
class FusionMeta(BaseModel):
    """Describes how the fusion model combined its submodels for a decision.

    Each mapping defaults to an empty dict. ``default_factory`` gives every
    instance its own fresh dict. Keys are presumably submodel names; confirm
    against the code that builds this object.
    """

    # Learned per-submodel coefficients of the fusion model.
    submodel_weights: Dict[str, float] = Field(
        default_factory=dict,
        description="Learned coefficients for each submodel",
    )
    # Per-prediction contribution (weight * prob_fake, per the description).
    weighted_contributions: Dict[str, float] = Field(
        default_factory=dict,
        description="Actual contribution to this prediction (weight * prob_fake)",
    )
    # Normalized version of the contributions, intended for display.
    contribution_percentages: Dict[str, float] = Field(
        default_factory=dict,
        description="Normalized percentages for display",
    )
class TimingInfo(BaseModel):
    """Per-stage wall-clock breakdown of a prediction request, in milliseconds.

    Only ``total`` is required. Each stage timing defaults to ``None`` when
    it was not recorded.
    """

    total: int = Field(default=..., description="Total time in milliseconds")
    download: Optional[int] = Field(
        default=None, description="Image download time in ms"
    )
    preprocess: Optional[int] = Field(
        default=None, description="Preprocessing time in ms"
    )
    inference: Optional[int] = Field(
        default=None, description="Model inference time in ms"
    )
    fusion: Optional[int] = Field(
        default=None, description="Fusion computation time in ms"
    )
class PredictResponse(BaseModel):
    """Response schema for prediction endpoint.

    ``final`` and ``timing_ms`` are always present. ``submodels``,
    ``fusion_meta`` and ``model_display_info`` are optional and default to
    ``None``. Per its description, ``submodels`` is populated when fusion was
    used and submodel results were requested.
    """

    # Pydantic v2 configuration (replaces the deprecated v1-style
    # ``class Config``). ``protected_namespaces=()`` is needed because the
    # public field ``model_display_info`` starts with ``model_``, which
    # Pydantic 2.x (< 2.10) reserves and warns about at class creation.
    # The field name is part of the API contract, so silence the warning
    # instead of renaming the field.
    model_config = ConfigDict(
        protected_namespaces=(),
        json_schema_extra={
            "example": {
                "final": {"pred": "fake", "pred_int": 1, "prob_fake": 0.6667},
                "fusion_used": True,
                "submodels": {
                    "cnn-transfer": {"pred": "fake", "pred_int": 1, "prob_fake": 0.82, "contribution_percentage": 45.2},
                    "vit-base": {"pred": "fake", "pred_int": 1, "prob_fake": 0.75, "contribution_percentage": 32.1},
                    "gradfield-cnn": {"pred": "fake", "pred_int": 1, "prob_fake": 0.91, "contribution_percentage": 22.7}
                },
                "timing_ms": {"total": 250, "inference": 200, "fusion": 5}
            }
        },
    )

    final: PredictionResult = Field(
        ...,
        description="Final prediction result"
    )
    fusion_used: bool = Field(
        ...,
        description="Whether fusion was used for this prediction"
    )
    submodels: Optional[Dict[str, PredictionResult]] = Field(
        None,
        description="Individual submodel predictions (when fusion_used=true and return_submodels=true)"
    )
    fusion_meta: Optional[FusionMeta] = Field(
        None,
        description="Fusion metadata including model weights and contributions"
    )
    model_display_info: Optional[Dict[str, ModelDisplayInfo]] = Field(
        None,
        description="Display information for each model (for frontend rendering)"
    )
    timing_ms: TimingInfo = Field(
        ...,
        description="Timing breakdown in milliseconds"
    )
# Request/Response for single-model explanation endpoint
class ExplainModelRequest(BaseModel):
    """Request schema for single-model explanation.

    ``model_name`` and ``prob_fake`` are required. The heatmap, focus
    summary and contribution fields are optional context and default to
    ``None``.
    """

    # ``model_name`` starts with the ``model_`` prefix that Pydantic 2.x
    # (< 2.10) reserves by default, which triggers a UserWarning at import.
    # The field name is part of the request contract, so disable the
    # protected namespace instead of renaming the field.
    model_config = ConfigDict(protected_namespaces=())

    model_name: str = Field(..., description="Name of the model to explain")
    prob_fake: float = Field(..., ge=0.0, le=1.0, description="Model's fake probability")
    heatmap_base64: Optional[str] = Field(None, description="Base64-encoded heatmap")
    focus_summary: Optional[str] = Field(None, description="Where the model focused")
    contribution_percentage: Optional[float] = Field(None, description="Model's contribution to fusion")
class ExplainModelResponse(BaseModel):
    """Response schema for single-model explanation.

    Pairs the internal model name with the LLM-generated insight for it.
    """

    # ``model_name`` collides with Pydantic 2.x's default ``model_``
    # protected namespace (< 2.10), which emits a UserWarning at import.
    # Keep the public field name and disable the namespace check instead.
    model_config = ConfigDict(protected_namespaces=())

    model_name: str = Field(..., description="Internal model name")
    insight: SingleModelInsight = Field(..., description="LLM-generated insight")
class ErrorResponse(BaseModel):
    """Uniform error payload: an error type, a message, and optional details.

    ``error`` and ``message`` are required. ``details`` is a free-form dict
    that defaults to ``None``.
    """

    error: str = Field(
        default=...,
        description="Error type",
    )
    message: str = Field(
        default=...,
        description="Error message",
    )
    details: Optional[Dict] = Field(
        default=None,
        description="Additional error details",
    )
|