|
|
""" |
|
|
Pydantic schemas for type safety and validation. |
|
|
""" |
|
|
from datetime import datetime |
|
|
from typing import Any, Dict, List, Optional |
|
|
from pydantic import BaseModel, Field, validator, field_validator |
|
|
import logging |
|
|
|
|
|
# Module-level logger; the before-mode validators below use it to flag inputs
# that had to be coerced into the declared field types.
logger = logging.getLogger(__name__)
|
|
|
|
|
|
|
|
class Paper(BaseModel):
    """Schema for arXiv paper metadata."""
    arxiv_id: str = Field(..., description="arXiv paper ID")
    title: str = Field(..., description="Paper title")
    authors: List[str] = Field(..., description="List of author names")
    abstract: str = Field(..., description="Paper abstract")
    pdf_url: str = Field(..., description="URL to PDF")
    published: datetime = Field(..., description="Publication date")
    categories: List[str] = Field(default_factory=list, description="arXiv categories")

    # NOTE: these validators run in "before" mode so they can coerce loosely
    # structured upstream payloads (dicts, bare strings, mixed lists) into the
    # declared field types before standard validation runs. Converted from the
    # deprecated Pydantic v1 `@validator(..., pre=True)` form to
    # `@field_validator(..., mode='before')` to match the other models in this
    # module (Analysis, ConsensusPoint, Contradiction, SynthesisResult).

    @field_validator('authors', mode='before')
    @classmethod
    def normalize_authors(cls, v):
        """Ensure authors is always a List[str], handling various input formats."""
        if isinstance(v, list):
            # Coerce any non-string entries to str; leave strings untouched.
            return [str(author) if not isinstance(author, str) else author for author in v]
        elif isinstance(v, dict):
            # Some sources wrap the author list in a dict; pull the likeliest key.
            logger.warning(f"Authors field is dict, extracting values: {v}")
            if 'names' in v:
                return v['names'] if isinstance(v['names'], list) else [str(v['names'])]
            elif 'authors' in v:
                return v['authors'] if isinstance(v['authors'], list) else [str(v['authors'])]
            else:
                # Last resort: keep every truthy value in the dict.
                return [str(val) for val in v.values() if val]
        elif isinstance(v, str):
            # A single author name may arrive as a bare string.
            return [v]
        else:
            logger.warning(f"Unexpected authors format: {type(v)}, returning empty list")
            return []

    @field_validator('categories', mode='before')
    @classmethod
    def normalize_categories(cls, v):
        """Ensure categories is always a List[str], handling various input formats."""
        if isinstance(v, list):
            # Coerce any non-string entries to str; leave strings untouched.
            return [str(cat) if not isinstance(cat, str) else cat for cat in v]
        elif isinstance(v, dict):
            logger.warning(f"Categories field is dict, extracting values: {v}")
            if 'categories' in v:
                return v['categories'] if isinstance(v['categories'], list) else [str(v['categories'])]
            else:
                # Last resort: keep every truthy value in the dict.
                return [str(val) for val in v.values() if val]
        elif isinstance(v, str):
            # A single category may arrive as a bare string.
            return [v]
        else:
            logger.warning(f"Unexpected categories format: {type(v)}, returning empty list")
            return []

    @field_validator('pdf_url', mode='before')
    @classmethod
    def normalize_pdf_url(cls, v):
        """Ensure pdf_url is always a string."""
        if isinstance(v, dict):
            logger.warning(f"pdf_url is dict, extracting url value: {v}")
            return v.get('url') or v.get('pdf_url') or str(v)
        return str(v) if v else ""

    @field_validator('title', mode='before')
    @classmethod
    def normalize_title(cls, v):
        """Ensure title is always a string."""
        if isinstance(v, dict):
            logger.warning(f"title is dict, extracting title value: {v}")
            return v.get('title') or str(v)
        return str(v) if v else ""

    @field_validator('abstract', mode='before')
    @classmethod
    def normalize_abstract(cls, v):
        """Ensure abstract is always a string."""
        if isinstance(v, dict):
            logger.warning(f"abstract is dict, extracting abstract value: {v}")
            return v.get('abstract') or v.get('summary') or str(v)
        return str(v) if v else ""

    class Config:
        # v1-style config retained for compatibility; serializes datetime
        # fields as ISO-8601 strings when dumping to JSON.
        json_encoders = {
            datetime: lambda v: v.isoformat()
        }
|
|
|
|
|
|
|
|
class PaperChunk(BaseModel):
    """Schema for chunked paper content."""
    # Identity: chunk_id is unique per chunk; paper_id ties the chunk to its Paper.
    chunk_id: str = Field(..., description="Unique chunk identifier")
    paper_id: str = Field(..., description="arXiv paper ID")
    content: str = Field(..., description="Chunk text content")
    # Optional provenance hints; None when the extractor could not determine them.
    section: Optional[str] = Field(None, description="Section name if available")
    page_number: Optional[int] = Field(None, description="Page number")
    arxiv_url: str = Field(..., description="arXiv URL for citation")
    # Free-form extra data; default_factory gives each instance its own dict.
    metadata: Dict[str, Any] = Field(default_factory=dict, description="Additional metadata")
|
|
|
|
|
|
|
|
class Analysis(BaseModel):
    """Schema for individual paper analysis."""
    paper_id: str = Field(..., description="arXiv paper ID")
    methodology: str = Field(..., description="Research methodology description")
    key_findings: List[str] = Field(..., description="Main findings from the paper")
    conclusions: str = Field(..., description="Paper conclusions")
    limitations: List[str] = Field(..., description="Study limitations")
    citations: List[str] = Field(..., description="Source locations for claims")
    main_contributions: List[str] = Field(default_factory=list, description="Key contributions")
    confidence_score: float = Field(..., ge=0.0, le=1.0, description="Analysis confidence")

    @field_validator('key_findings', 'limitations', 'citations', 'main_contributions', mode='before')
    @classmethod
    def normalize_string_lists(cls, v, info):
        """
        Normalize list fields to ensure they contain only strings.
        Handles nested lists, None values, and mixed types.
        """
        def walk(value):
            # Depth-first traversal yielding every non-empty stripped string;
            # None is dropped, non-strings are stringified, lists recurse.
            if isinstance(value, list):
                for element in value:
                    yield from walk(element)
            elif value is not None:
                text = str(value).strip()
                if text:
                    yield text

        result = list(walk(v))
        if v != result:
            logger.warning(f"Normalized '{info.field_name}' in Analysis: cleaned nested/invalid values")
        return result
|
|
|
|
|
|
|
|
class ConsensusPoint(BaseModel):
    """Schema for consensus findings across papers."""
    statement: str = Field(..., description="Consensus statement")
    supporting_papers: List[str] = Field(..., description="Paper IDs supporting this claim")
    citations: List[str] = Field(..., description="Specific citations")
    confidence: float = Field(..., ge=0.0, le=1.0, description="Confidence in consensus")

    @field_validator('supporting_papers', 'citations', mode='before')
    @classmethod
    def normalize_string_lists(cls, v, info):
        """Normalize list fields to ensure they contain only strings."""
        def iter_strings(value):
            # Recursively yield stripped, non-empty string forms of `value`:
            # lists are flattened, None is skipped, other types stringified.
            if isinstance(value, list):
                for entry in value:
                    yield from iter_strings(entry)
            elif value is not None:
                cleaned = str(value).strip()
                if cleaned:
                    yield cleaned

        result = list(iter_strings(v))
        if v != result:
            logger.warning(f"Normalized '{info.field_name}' in ConsensusPoint: cleaned nested/invalid values")
        return result
|
|
|
|
|
|
|
|
class Contradiction(BaseModel):
    """Schema for contradictory findings."""
    topic: str = Field(..., description="Topic of contradiction")
    viewpoint_a: str = Field(..., description="First viewpoint")
    papers_a: List[str] = Field(..., description="Papers supporting viewpoint A")
    viewpoint_b: str = Field(..., description="Second viewpoint")
    papers_b: List[str] = Field(..., description="Papers supporting viewpoint B")
    citations: List[str] = Field(..., description="Specific citations for both sides")
    confidence: float = Field(..., ge=0.0, le=1.0, description="Confidence in contradiction")

    @field_validator('papers_a', 'papers_b', 'citations', mode='before')
    @classmethod
    def normalize_string_lists(cls, v, info):
        """Normalize list fields to ensure they contain only strings."""
        def gather(value, out):
            # Flatten nested lists into `out`, keeping only non-empty
            # stripped string forms; None values are discarded.
            if isinstance(value, list):
                for entry in value:
                    gather(entry, out)
            elif value is not None:
                text = str(value).strip()
                if text:
                    out.append(text)
            return out

        result = gather(v, [])
        if v != result:
            logger.warning(f"Normalized '{info.field_name}' in Contradiction: cleaned nested/invalid values")
        return result
|
|
|
|
|
|
|
|
class SynthesisResult(BaseModel):
    """Schema for synthesis across multiple papers."""
    consensus_points: List[ConsensusPoint] = Field(..., description="Areas of agreement")
    contradictions: List[Contradiction] = Field(..., description="Areas of disagreement")
    research_gaps: List[str] = Field(..., description="Identified research gaps")
    summary: str = Field(..., description="Executive summary")
    confidence_score: float = Field(..., ge=0.0, le=1.0, description="Overall confidence")
    papers_analyzed: List[str] = Field(..., description="List of paper IDs analyzed")

    @field_validator('research_gaps', 'papers_analyzed', mode='before')
    @classmethod
    def normalize_string_lists(cls, v, info):
        """Normalize list fields to ensure they contain only strings."""
        def flattened(value):
            # Generator over all non-blank stripped strings reachable in
            # `value`; arbitrary nesting is handled recursively and None
            # entries are dropped.
            if isinstance(value, list):
                for item in value:
                    yield from flattened(item)
            elif value is not None:
                stripped = str(value).strip()
                if stripped:
                    yield stripped

        result = list(flattened(v))
        if v != result:
            logger.warning(f"Normalized '{info.field_name}' in SynthesisResult: cleaned nested/invalid values")
        return result
|
|
|
|
|
|
|
|
class Citation(BaseModel):
    """Schema for properly formatted citations."""
    paper_id: str = Field(..., description="arXiv paper ID")
    authors: List[str] = Field(..., description="Paper authors")
    year: int = Field(..., description="Publication year")
    title: str = Field(..., description="Paper title")
    source: str = Field(..., description="Publication source (arXiv)")
    # Pre-rendered APA string built from the fields above; producer is
    # responsible for keeping it consistent with them.
    apa_format: str = Field(..., description="Full APA formatted citation")
    url: str = Field(..., description="arXiv URL")
|
|
|
|
|
|
|
|
class ValidatedOutput(BaseModel):
    """Schema for final validated output with citations."""
    synthesis: SynthesisResult = Field(..., description="Synthesis results")
    citations: List[Citation] = Field(..., description="All citations used")
    # IDs of retrieved chunks that ground the synthesis claims.
    retrieved_chunks: List[str] = Field(..., description="Chunk IDs used for grounding")
    # Bookkeeping dicts; presumably keyed by model or pipeline stage —
    # TODO(review): confirm key semantics against the producer.
    token_usage: Dict[str, int] = Field(default_factory=dict, description="Token usage stats")
    model_desc: Dict[str, str] = Field(default_factory=dict, description="Model descriptions")
    cost_estimate: float = Field(..., description="Estimated cost in USD")
    processing_time: float = Field(..., description="Processing time in seconds")
|
|
|
|
|
|
|
|
class AgentState(BaseModel):
    """
    Schema for LangGraph state management.

    Note: This Pydantic model serves as type documentation and validation reference.
    The actual LangGraph workflow in app.py uses Dict[str, Any] for state to maintain
    compatibility with Gradio progress tracking and dynamic state updates during execution.

    All fields in this schema correspond to keys in the workflow state dictionary.
    """
    # --- Inputs -------------------------------------------------------------
    query: str = Field(..., description="User research question")
    category: Optional[str] = Field(None, description="arXiv category filter")
    num_papers: int = Field(default=5, ge=1, le=20, description="Number of papers to retrieve")
    # --- Intermediate artifacts accumulated as the workflow advances --------
    papers: List[Paper] = Field(default_factory=list, description="Retrieved papers")
    chunks: List[PaperChunk] = Field(default_factory=list, description="Chunked content")
    analyses: List[Analysis] = Field(default_factory=list, description="Individual analyses")
    # --- Outputs (None until the corresponding stage completes) -------------
    synthesis: Optional[SynthesisResult] = Field(None, description="Synthesis result")
    validated_output: Optional[ValidatedOutput] = Field(None, description="Final output")
    # Non-fatal error messages collected across stages.
    errors: List[str] = Field(default_factory=list, description="Error messages")

    class Config:
        # Allow fields typed with non-Pydantic classes (v1-style config flag).
        arbitrary_types_allowed = True
|
|
|