""" Pydantic schemas for API request/response models. """ from typing import List, Optional from pydantic import BaseModel, HttpUrl, Field, field_validator class PromptParameters(BaseModel): """Parameters for customizing the document generation prompt.""" language: str = Field( default="English", description="Language for generated documents" ) doc_type: str = Field( default="business and administrative", description="Type of documents to generate (e.g., 'business and administrative', 'receipts', 'forms')" ) gt_type: str = Field( default="Multiple questions about each document, with their answers taken **verbatim** from the document.", description="Description of ground truth type to generate" ) gt_format: str = Field( default='{"": "", "": "", ...}', description="Format specification for ground truth JSON" ) num_solutions: int = Field( default=1, ge=1, le=5, description="Number of document variations to generate (1-5)" ) # Stage 3: Feature Synthesis parameters enable_handwriting: bool = Field( default=False, description="Enable handwriting generation (requires EC2 handwriting service)" ) handwriting_ratio: float = Field( default=0.2, ge=0.0, le=1.0, description="Proportion of text to convert to handwriting (0.0-1.0)" ) handwriting_apply_ink_filter: bool = Field( default=True, description="Apply high-contrast ink filter to handwriting (v16+ feature)" ) handwriting_enable_enhancements: bool = Field( default=False, description="Enable sharpening and contrast boosting (Experimental)" ) handwriting_num_inference_steps: int = Field( default=1000, ge=1, le=1000, description="Number of diffusion inference steps (1-1000)" ) handwriting_writer_ids: List[int] = Field( default=[404, 347, 156, 253, 354, 166, 320], description="List of writer style IDs to use for handwriting generation" ) enable_visual_elements: bool = Field( default=True, description="Enable visual element generation (stamps, logos, barcodes)" ) visual_element_types: List[str] = Field( default=["stamp", "logo", "figure", "barcode", "photo"], description="Types of visual elements to generate (stamp, logo, figure, barcode, photo)" ) barcode_number: Optional[str] = Field( default=None, description="Optional fixed number for barcode generation (numeric only)" ) seed: Optional[int] = Field( default=None, description="Random seed for reproducible generation", examples=[None, 42] ) # Stage 4: Image Finalization & OCR parameters enable_ocr: bool = Field( default=True, description="Enable OCR on final document images (requires OCR service)" ) ocr_language: str = Field( default="en", description="Language for OCR (e.g., 'en', 'de', 'fr')" ) # Stage 5: Dataset Packaging parameters enable_bbox_normalization: bool = Field( default=True, description="Normalize bounding boxes to [0,1] scale (Stage 16)" ) enable_gt_verification: bool = Field( default=True, description="Verify and prepare ground truth annotations (Stage 17)" ) enable_analysis: bool = Field( default=True, description="Generate dataset statistics and analysis (Stage 18)" ) enable_debug_visualization: bool = Field( default=True, description="Create debug visualization overlays (Stage 19)" ) enable_dataset_export: bool = Field( default=True, description="Export as msgpack dataset format" ) dataset_export_format: str = Field( default="msgpack", description="Dataset export format: 'msgpack', 'coco', 'huggingface'" ) output_detail: str = Field( default="dataset", description="Output detail level: 'minimal' (final outputs only), 'dataset' (includes individual tokens/elements for ML), 'complete' (all intermediate files and debug info). Warning: 'complete' mode can produce 50+ MB responses." ) class SeedImage(BaseModel): """Seed image URL for document generation.""" url: HttpUrl = Field( description="URL of the seed image", default=HttpUrl("https://ocr.space/Content/Images/receipt-ocr-original.webp") ) class GenerateDocumentRequest(BaseModel): """Request schema for document generation endpoint.""" request_id: str = Field( description="Document request UUID from document_requests table (created by frontend)" ) google_drive_token: Optional[str] = Field( default=None, description="Google Drive OAuth access token. Frontend provides this after OAuth flow (optional)." ) google_drive_refresh_token: Optional[str] = Field( default=None, description="Google Drive refresh token (optional, for automatic token renewal)" ) seed_images: List[HttpUrl] = Field( default=[HttpUrl("https://ocr.space/Content/Images/receipt-ocr-original.webp")], description="List of seed image URLs (1-10 images)" ) prompt_params: PromptParameters = Field( default_factory=PromptParameters, description="Parameters for customizing the generation prompt" ) @field_validator('seed_images') @classmethod def validate_seed_images(cls, v): if not v: raise ValueError('At least one seed image is required') if len(v) < 1: raise ValueError('At least one seed image is required') if len(v) > 10: raise ValueError('Maximum 10 seed images allowed') return v class OCRWord(BaseModel): """OCR word-level result.""" text: str = Field(description="Recognized text") confidence: float = Field(ge=0.0, le=1.0, description="OCR confidence score (0-1)") x: float = Field(description="X coordinate (pixels)") y: float = Field(description="Y coordinate (pixels)") width: float = Field(description="Width (pixels)") height: float = Field(description="Height (pixels)") class OCRLine(BaseModel): """OCR line-level result.""" text: str = Field(description="Recognized text") confidence: float = Field(ge=0.0, le=1.0, description="OCR confidence score (0-1)") x: float = Field(description="X coordinate (pixels)") y: float = Field(description="Y coordinate (pixels)") width: float = Field(description="Width (pixels)") height: float = Field(description="Height (pixels)") words: List[OCRWord] = Field(default_factory=list, description="Words in this line") class OCRResult(BaseModel): """OCR results for a document.""" image_width: int = Field(description="Image width in pixels") image_height: int = Field(description="Image height in pixels") words: List[OCRWord] = Field(default_factory=list, description="Word-level OCR results") lines: List[OCRLine] = Field(default_factory=list, description="Line-level OCR results") angle: float = Field(default=0.0, description="Detected text orientation angle") class CostInfo(BaseModel): """Cost information for a request (Research Parity).""" input_tokens: int = Field(description="Number of input tokens") output_tokens: int = Field(description="Number of output tokens") cache_creation_tokens: int = Field(default=0, description="Tokens used for cache creation") cache_read_tokens: int = Field(default=0, description="Tokens read from cache") cost_usd: float = Field(description="Total cost in USD (with 50% batch discount applied if applicable)") batch_discount_applied: bool = Field(default=False, description="Whether 50% batch discount was applied") class NormalizedBBox(BaseModel): """Normalized bounding box (Stage 16).""" text: str = Field(description="Text content") x0: float = Field(ge=0.0, le=1.0, description="Normalized X min (0-1)") y0: float = Field(ge=0.0, le=1.0, description="Normalized Y min (0-1)") x2: float = Field(ge=0.0, le=1.0, description="Normalized X max (0-1)") y2: float = Field(ge=0.0, le=1.0, description="Normalized Y max (0-1)") block_no: Optional[int] = Field(default=None, description="Block number") line_no: Optional[int] = Field(default=None, description="Line number") word_no: Optional[int] = Field(default=None, description="Word number") class GTVerificationResult(BaseModel): """Ground truth verification results (Stage 17).""" passed: bool = Field(description="Whether GT verification passed") skipped: bool = Field(default=False, description="Whether verification was skipped") confirmed_keys: List[str] = Field(default_factory=list, description="Confirmed GT keys") similarities: List[float] = Field(default_factory=list, description="Similarity scores") num_layout_elements: Optional[int] = Field(default=None, description="Number of layout elements") valid_labels: bool = Field(default=True, description="Whether all labels are valid") class AnalysisStats(BaseModel): """Dataset analysis and statistics (Stage 18).""" total_documents: int = Field(description="Total documents processed") valid_documents: int = Field(description="Documents passing all validation") error_counts: dict = Field(default_factory=dict, description="Error type counts") has_handwriting: int = Field(default=0, description="Documents with handwriting") has_visual_elements: int = Field(default=0, description="Documents with visual elements") has_ocr: int = Field(default=0, description="Documents with OCR results") multipage_count: int = Field(default=0, description="Multipage documents") token_usage: Optional[dict] = Field(default=None, description="LLM token usage statistics") class DebugVisualization(BaseModel): """Debug visualization data (Stage 19).""" bbox_overlay_base64: Optional[str] = Field(default=None, description="Image with bbox overlays (PNG base64)") visual_elements_overlay_base64: Optional[str] = Field(default=None, description="Image with visual element overlays") handwriting_overlay_base64: Optional[str] = Field(default=None, description="Image with handwriting overlays") class DatasetExportInfo(BaseModel): """Dataset export metadata.""" format: str = Field(description="Export format (msgpack, coco, etc.)") num_samples: int = Field(description="Number of samples in export") output_path: Optional[str] = Field(default=None, description="Path to exported dataset") msgpack_base64: Optional[str] = Field(default=None, description="Msgpack file as base64 (for small datasets)") metadata: dict = Field(default_factory=dict, description="Dataset metadata") class BoundingBox(BaseModel): """Bounding box for a text element in the document.""" text: str = Field(description="Text content") x: float = Field(description="X coordinate (normalized 0-1)") y: float = Field(description="Y coordinate (normalized 0-1)") width: float = Field(description="Width (normalized 0-1)") height: float = Field(description="Height (normalized 0-1)") page: int = Field(default=0, description="Page number (0-indexed)") class HandwritingRegion(BaseModel): """Information about a handwriting region in the document.""" region_id: str = Field(description="Unique region identifier") text: str = Field(description="Text content") author_id: int = Field(ge=0, le=656, description="Author ID for style consistency (0-656)") bbox: BoundingBox = Field(description="Bounding box of the region") class VisualElement(BaseModel): """Information about a visual element in the document.""" element_id: str = Field(description="Unique element identifier") element_type: str = Field(description="Type of visual element (stamp, logo, etc.)") content: Optional[str] = Field(default=None, description="Content (e.g., stamp text)") bbox: BoundingBox = Field(description="Bounding box of the element") class DocumentResult(BaseModel): """Result for a single generated document.""" document_id: str = Field(description="Unique document identifier") html: str = Field(description="Generated HTML content") css: str = Field(description="Extracted CSS styles") ground_truth: Optional[dict] = Field( default=None, description="Ground truth data extracted from the document" ) pdf_base64: str = Field(description="Base64-encoded PDF document") bboxes: List[BoundingBox] = Field( default_factory=list, description="Bounding boxes for text elements" ) page_width_mm: float = Field(description="Page width in millimeters") page_height_mm: float = Field(description="Page height in millimeters") # Stage 3 additions handwriting_regions: Optional[List[dict]] = Field( default=None, description="Handwriting regions with metadata (if enabled)" ) visual_elements: Optional[List[dict]] = Field( default=None, description="Visual elements with metadata (if enabled)" ) image_base64: Optional[str] = Field( default=None, description="Final rendered image with handwriting/visuals (PNG base64, if Stage 3 enabled)" ) # Stage 3 individual tokens (dataset/complete output detail levels) handwriting_token_images: Optional[dict] = Field( default=None, description="Individual handwriting token images {hw_id: base64_png} (output_detail: dataset/complete)" ) visual_element_images: Optional[dict] = Field( default=None, description="Individual visual element images {ve_id: base64_png} (output_detail: dataset/complete)" ) token_mapping: Optional[dict] = Field( default=None, description="Token mapping with positions and style IDs (output_detail: dataset/complete)" ) # Stage 4 additions ocr_results: Optional[OCRResult] = Field( default=None, description="OCR results from final image (if OCR enabled)" ) # Stage 5 additions normalized_bboxes_word: Optional[List[NormalizedBBox]] = Field( default=None, description="Word-level normalized bounding boxes (if Stage 16 enabled)" ) normalized_bboxes_segment: Optional[List[NormalizedBBox]] = Field( default=None, description="Segment-level normalized bounding boxes (if Stage 16 enabled)" ) gt_verification: Optional[GTVerificationResult] = Field( default=None, description="Ground truth verification results (if Stage 17 enabled)" ) analysis_stats: Optional[AnalysisStats] = Field( default=None, description="Document analysis statistics (if Stage 18 enabled)" ) debug_visualization: Optional[DebugVisualization] = Field( default=None, description="Debug visualization overlays (if Stage 19 enabled)" ) dataset_export: Optional[DatasetExportInfo] = Field( default=None, description="Dataset export information (if export enabled)" ) cost_info: Optional[CostInfo] = Field( default=None, description="Cost information for this document (Research Parity)" ) class GenerateDocumentResponse(BaseModel): """Response schema for document generation endpoint.""" success: bool = Field(description="Whether generation was successful") message: str = Field(description="Status message") documents: List[DocumentResult] = Field( default_factory=list, description="List of generated documents" ) total_documents: int = Field( default=0, description="Total number of documents generated" ) total_cost: Optional[CostInfo] = Field( default=None, description="Aggregated cost for the entire request" ) class HealthResponse(BaseModel): """Health check response.""" status: str = Field(default="healthy") version: str = Field(default="1.0.0")