| """ |
| Pydantic schemas for API request/response models. |
| """ |
| from typing import List, Optional |
| from pydantic import BaseModel, HttpUrl, Field, field_validator |
|
|
|
|
class PromptParameters(BaseModel):
    """Parameters for customizing the document generation prompt.

    Every field has a default, so an empty object is a valid value. Fields are
    grouped by feature: prompt content, handwriting, visual elements, OCR, and
    the post-processing / export stages (16-19).

    Constraints that the field descriptions document (numeric-only barcode,
    closed sets for ``output_detail`` / ``dataset_export_format``, writer IDs
    in 0-656) are enforced by the validators at the bottom of the class.
    """
    language: str = Field(
        default="English",
        description="Language for generated documents"
    )
    doc_type: str = Field(
        default="business and administrative",
        description="Type of documents to generate (e.g., 'business and administrative', 'receipts', 'forms')"
    )
    gt_type: str = Field(
        default="Multiple questions about each document, with their answers taken **verbatim** from the document.",
        description="Description of ground truth type to generate"
    )
    gt_format: str = Field(
        default='{"<Text of question 1>": "<Answer to question 1>", "<Text of question 2>": "<Answer to question 2>", ...}',
        description="Format specification for ground truth JSON"
    )
    num_solutions: int = Field(
        default=1,
        ge=1,
        le=5,
        description="Number of document variations to generate (1-5)"
    )

    # --- Handwriting -------------------------------------------------------
    enable_handwriting: bool = Field(
        default=False,
        description="Enable handwriting generation (requires EC2 handwriting service)"
    )
    handwriting_ratio: float = Field(
        default=0.2,
        ge=0.0,
        le=1.0,
        description="Proportion of text to convert to handwriting (0.0-1.0)"
    )
    handwriting_apply_ink_filter: bool = Field(
        default=True,
        description="Apply high-contrast ink filter to handwriting (v16+ feature)"
    )
    handwriting_enable_enhancements: bool = Field(
        default=False,
        description="Enable sharpening and contrast boosting (Experimental)"
    )
    handwriting_num_inference_steps: int = Field(
        default=1000,
        ge=1,
        le=1000,
        description="Number of diffusion inference steps (1-1000)"
    )
    # A literal list default is fine here: Pydantic copies non-hashable
    # defaults per instance, and keeping `default=` preserves it in the schema.
    handwriting_writer_ids: List[int] = Field(
        default=[404, 347, 156, 253, 354, 166, 320],
        description="List of writer style IDs to use for handwriting generation"
    )

    # --- Visual elements ---------------------------------------------------
    enable_visual_elements: bool = Field(
        default=True,
        description="Enable visual element generation (stamps, logos, barcodes)"
    )
    visual_element_types: List[str] = Field(
        default=["stamp", "logo", "figure", "barcode", "photo"],
        description="Types of visual elements to generate (stamp, logo, figure, barcode, photo)"
    )
    barcode_number: Optional[str] = Field(
        default=None,
        description="Optional fixed number for barcode generation (numeric only)"
    )
    seed: Optional[int] = Field(
        default=None,
        description="Random seed for reproducible generation",
        examples=[None, 42]
    )

    # --- OCR ---------------------------------------------------------------
    enable_ocr: bool = Field(
        default=True,
        description="Enable OCR on final document images (requires OCR service)"
    )
    ocr_language: str = Field(
        default="en",
        description="Language for OCR (e.g., 'en', 'de', 'fr')"
    )

    # --- Post-processing stages and export ---------------------------------
    enable_bbox_normalization: bool = Field(
        default=True,
        description="Normalize bounding boxes to [0,1] scale (Stage 16)"
    )
    enable_gt_verification: bool = Field(
        default=True,
        description="Verify and prepare ground truth annotations (Stage 17)"
    )
    enable_analysis: bool = Field(
        default=True,
        description="Generate dataset statistics and analysis (Stage 18)"
    )
    enable_debug_visualization: bool = Field(
        default=True,
        description="Create debug visualization overlays (Stage 19)"
    )
    enable_dataset_export: bool = Field(
        default=True,
        description="Export as msgpack dataset format"
    )
    dataset_export_format: str = Field(
        default="msgpack",
        description="Dataset export format: 'msgpack', 'coco', 'huggingface'"
    )
    output_detail: str = Field(
        default="dataset",
        description="Output detail level: 'minimal' (final outputs only), 'dataset' (includes individual tokens/elements for ML), 'complete' (all intermediate files and debug info). Warning: 'complete' mode can produce 50+ MB responses."
    )

    @field_validator('handwriting_writer_ids')
    @classmethod
    def validate_handwriting_writer_ids(cls, v: List[int]) -> List[int]:
        """Reject writer IDs outside 0-656.

        NOTE(review): the bound mirrors ``HandwritingRegion.author_id``
        (ge=0, le=656), assuming writer IDs and author IDs are the same style
        IDs - confirm against the handwriting service.

        Raises:
            ValueError: if any ID is outside the range.
        """
        out_of_range = [writer_id for writer_id in v if not 0 <= writer_id <= 656]
        if out_of_range:
            raise ValueError(
                f'Writer IDs must be between 0 and 656; got {out_of_range}'
            )
        return v

    @field_validator('barcode_number')
    @classmethod
    def validate_barcode_number(cls, v: Optional[str]) -> Optional[str]:
        """Enforce the documented "numeric only" rule for a fixed barcode number.

        ``None`` and the empty string pass through unchanged so that a blank
        form field is not rejected. ``isascii`` is checked because ``isdigit``
        alone also accepts characters such as superscript digits.

        Raises:
            ValueError: if a non-empty value contains anything but ASCII digits.
        """
        if v and not (v.isascii() and v.isdigit()):
            raise ValueError('barcode_number must contain ASCII digits only')
        return v

    @field_validator('dataset_export_format')
    @classmethod
    def validate_dataset_export_format(cls, v: str) -> str:
        """Restrict the export format to the values listed in its description.

        Raises:
            ValueError: if the format is not one of the documented values.
        """
        allowed = ('msgpack', 'coco', 'huggingface')
        if v not in allowed:
            raise ValueError(f'dataset_export_format must be one of {allowed}; got {v!r}')
        return v

    @field_validator('output_detail')
    @classmethod
    def validate_output_detail(cls, v: str) -> str:
        """Restrict the output detail level to the values listed in its description.

        Raises:
            ValueError: if the level is not one of the documented values.
        """
        allowed = ('minimal', 'dataset', 'complete')
        if v not in allowed:
            raise ValueError(f'output_detail must be one of {allowed}; got {v!r}')
        return v
|
|
|
|
class SeedImage(BaseModel):
    """A single seed image, referenced by URL, used to guide document generation.

    The URL defaults to a sample receipt image when none is supplied.
    """
    url: HttpUrl = Field(
        default=HttpUrl("https://ocr.space/Content/Images/receipt-ocr-original.webp"),
        description="URL of the seed image",
    )
|
|
|
|
class GenerateDocumentRequest(BaseModel):
    """Request schema for document generation endpoint.

    Only ``request_id`` is required. ``seed_images`` defaults to a single
    sample receipt URL and ``prompt_params`` defaults to ``PromptParameters()``.
    Explicitly supplied seed image lists must hold between 1 and 10 URLs.
    """
    request_id: str = Field(
        description="Document request UUID from document_requests table (created by frontend)"
    )
    google_drive_token: Optional[str] = Field(
        default=None,
        description="Google Drive OAuth access token. Frontend provides this after OAuth flow (optional)."
    )
    google_drive_refresh_token: Optional[str] = Field(
        default=None,
        description="Google Drive refresh token (optional, for automatic token renewal)"
    )
    seed_images: List[HttpUrl] = Field(
        default=[HttpUrl("https://ocr.space/Content/Images/receipt-ocr-original.webp")],
        description="List of seed image URLs (1-10 images)"
    )
    prompt_params: PromptParameters = Field(
        default_factory=PromptParameters,
        description="Parameters for customizing the generation prompt"
    )

    @field_validator('seed_images')
    @classmethod
    def validate_seed_images(cls, v: List[HttpUrl]) -> List[HttpUrl]:
        """Enforce the documented 1-10 bound on the number of seed images.

        Raises:
            ValueError: if the list is empty or contains more than 10 URLs.
        """
        # An empty list is falsy, so this single check also covers len(v) < 1.
        if not v:
            raise ValueError('At least one seed image is required')
        if len(v) > 10:
            raise ValueError('Maximum 10 seed images allowed')
        return v
|
|
|
|
class OCRWord(BaseModel):
    """A single word recognized by OCR, with its confidence and pixel-space box."""
    text: str = Field(..., description="Recognized text")
    confidence: float = Field(
        ...,
        description="OCR confidence score (0-1)",
        ge=0.0,
        le=1.0,
    )
    # Box geometry in image pixels.
    x: float = Field(..., description="X coordinate (pixels)")
    y: float = Field(..., description="Y coordinate (pixels)")
    width: float = Field(..., description="Width (pixels)")
    height: float = Field(..., description="Height (pixels)")
|
|
|
|
class OCRLine(BaseModel):
    """A line of OCR text with its confidence, pixel-space box, and member words."""
    text: str = Field(..., description="Recognized text")
    confidence: float = Field(
        ...,
        description="OCR confidence score (0-1)",
        ge=0.0,
        le=1.0,
    )
    # Box geometry in image pixels.
    x: float = Field(..., description="X coordinate (pixels)")
    y: float = Field(..., description="Y coordinate (pixels)")
    width: float = Field(..., description="Width (pixels)")
    height: float = Field(..., description="Height (pixels)")
    words: List[OCRWord] = Field(
        description="Words in this line",
        default_factory=list,
    )
|
|
|
|
class OCRResult(BaseModel):
    """Full OCR output for one document image.

    Contains the image dimensions, word- and line-level results (both empty
    by default), and the detected text orientation angle.
    """
    image_width: int = Field(..., description="Image width in pixels")
    image_height: int = Field(..., description="Image height in pixels")
    words: List[OCRWord] = Field(
        description="Word-level OCR results",
        default_factory=list,
    )
    lines: List[OCRLine] = Field(
        description="Line-level OCR results",
        default_factory=list,
    )
    angle: float = Field(0.0, description="Detected text orientation angle")
|
|
|
|
class CostInfo(BaseModel):
    """Token usage and USD cost for a request (Research Parity).

    Input and output token counts and the total cost are required; the cache
    token counters default to 0 and the batch-discount flag to False.
    """
    input_tokens: int = Field(..., description="Number of input tokens")
    output_tokens: int = Field(..., description="Number of output tokens")
    cache_creation_tokens: int = Field(0, description="Tokens used for cache creation")
    cache_read_tokens: int = Field(0, description="Tokens read from cache")
    cost_usd: float = Field(
        ...,
        description="Total cost in USD (with 50% batch discount applied if applicable)",
    )
    batch_discount_applied: bool = Field(
        False,
        description="Whether 50% batch discount was applied",
    )
|
|
|
|
class NormalizedBBox(BaseModel):
    """Text box whose corner coordinates are scaled to [0, 1] (Stage 16).

    (x0, y0) is the minimum corner and (x2, y2) the maximum corner. Block,
    line, and word numbers are optional.
    """
    text: str = Field(..., description="Text content")
    # Each coordinate is range-checked to [0, 1]; corner ordering is not checked.
    x0: float = Field(..., description="Normalized X min (0-1)", ge=0.0, le=1.0)
    y0: float = Field(..., description="Normalized Y min (0-1)", ge=0.0, le=1.0)
    x2: float = Field(..., description="Normalized X max (0-1)", ge=0.0, le=1.0)
    y2: float = Field(..., description="Normalized Y max (0-1)", ge=0.0, le=1.0)
    block_no: Optional[int] = Field(None, description="Block number")
    line_no: Optional[int] = Field(None, description="Line number")
    word_no: Optional[int] = Field(None, description="Word number")
|
|
|
|
class GTVerificationResult(BaseModel):
    """Outcome of ground truth verification for one document (Stage 17).

    Only ``passed`` is required. The other fields carry the confirmed keys,
    their similarity scores, and bookkeeping flags.
    """
    passed: bool = Field(..., description="Whether GT verification passed")
    skipped: bool = Field(False, description="Whether verification was skipped")
    confirmed_keys: List[str] = Field(
        description="Confirmed GT keys",
        default_factory=list,
    )
    similarities: List[float] = Field(
        description="Similarity scores",
        default_factory=list,
    )
    num_layout_elements: Optional[int] = Field(
        None,
        description="Number of layout elements",
    )
    valid_labels: bool = Field(True, description="Whether all labels are valid")
|
|
|
|
class AnalysisStats(BaseModel):
    """Aggregate dataset statistics produced by the analysis stage (Stage 18).

    Total and valid document counts are required. Feature counters default
    to 0, and the error-count and token-usage mappings are free-form dicts.
    """
    total_documents: int = Field(..., description="Total documents processed")
    valid_documents: int = Field(..., description="Documents passing all validation")
    error_counts: dict = Field(
        description="Error type counts",
        default_factory=dict,
    )
    # Per-feature document counters.
    has_handwriting: int = Field(0, description="Documents with handwriting")
    has_visual_elements: int = Field(0, description="Documents with visual elements")
    has_ocr: int = Field(0, description="Documents with OCR results")
    multipage_count: int = Field(0, description="Multipage documents")
    token_usage: Optional[dict] = Field(
        None,
        description="LLM token usage statistics",
    )
|
|
|
|
class DebugVisualization(BaseModel):
    """Optional base64-encoded overlay images for debugging (Stage 19)."""
    bbox_overlay_base64: Optional[str] = Field(
        None,
        description="Image with bbox overlays (PNG base64)",
    )
    visual_elements_overlay_base64: Optional[str] = Field(
        None,
        description="Image with visual element overlays",
    )
    handwriting_overlay_base64: Optional[str] = Field(
        None,
        description="Image with handwriting overlays",
    )
|
|
|
|
class DatasetExportInfo(BaseModel):
    """Metadata describing an exported dataset.

    The output path and the inline msgpack payload are optional.
    """
    # The name shadows the builtin `format`, but it is part of the API payload.
    format: str = Field(..., description="Export format (msgpack, coco, etc.)")
    num_samples: int = Field(..., description="Number of samples in export")
    output_path: Optional[str] = Field(None, description="Path to exported dataset")
    msgpack_base64: Optional[str] = Field(
        None,
        description="Msgpack file as base64 (for small datasets)",
    )
    metadata: dict = Field(description="Dataset metadata", default_factory=dict)
|
|
|
|
class BoundingBox(BaseModel):
    """Box around a text element, with its page index.

    The coordinates are described as normalized to 0-1, but the range is not
    validated here.
    """
    text: str = Field(..., description="Text content")
    x: float = Field(..., description="X coordinate (normalized 0-1)")
    y: float = Field(..., description="Y coordinate (normalized 0-1)")
    width: float = Field(..., description="Width (normalized 0-1)")
    height: float = Field(..., description="Height (normalized 0-1)")
    page: int = Field(0, description="Page number (0-indexed)")
|
|
|
|
class HandwritingRegion(BaseModel):
    """A handwritten span of text: its ID, content, author style, and location."""
    region_id: str = Field(..., description="Unique region identifier")
    text: str = Field(..., description="Text content")
    author_id: int = Field(
        ...,
        description="Author ID for style consistency (0-656)",
        ge=0,
        le=656,
    )
    bbox: BoundingBox = Field(..., description="Bounding box of the region")
|
|
|
|
class VisualElement(BaseModel):
    """A non-text element placed on the document, such as a stamp or logo.

    ``content`` optionally carries the element's text (for example, stamp text).
    """
    element_id: str = Field(..., description="Unique element identifier")
    element_type: str = Field(..., description="Type of visual element (stamp, logo, etc.)")
    content: Optional[str] = Field(None, description="Content (e.g., stamp text)")
    bbox: BoundingBox = Field(..., description="Bounding box of the element")
|
|
|
|
class DocumentResult(BaseModel):
    """Everything produced for one generated document.

    The required fields are the identifier, HTML/CSS, PDF, and page size.
    Every other field is optional and filled only when the corresponding
    feature or pipeline stage is enabled.
    """
    # Core render outputs.
    document_id: str = Field(..., description="Unique document identifier")
    html: str = Field(..., description="Generated HTML content")
    css: str = Field(..., description="Extracted CSS styles")
    ground_truth: Optional[dict] = Field(
        None,
        description="Ground truth data extracted from the document",
    )
    pdf_base64: str = Field(..., description="Base64-encoded PDF document")
    bboxes: List[BoundingBox] = Field(
        description="Bounding boxes for text elements",
        default_factory=list,
    )
    page_width_mm: float = Field(..., description="Page width in millimeters")
    page_height_mm: float = Field(..., description="Page height in millimeters")

    # Handwriting / visual element outputs. These are untyped dicts rather
    # than HandwritingRegion / VisualElement models.
    handwriting_regions: Optional[List[dict]] = Field(
        None,
        description="Handwriting regions with metadata (if enabled)",
    )
    visual_elements: Optional[List[dict]] = Field(
        None,
        description="Visual elements with metadata (if enabled)",
    )
    image_base64: Optional[str] = Field(
        None,
        description="Final rendered image with handwriting/visuals (PNG base64, if Stage 3 enabled)",
    )

    # Per-token artifacts, emitted for output_detail 'dataset' or 'complete'.
    handwriting_token_images: Optional[dict] = Field(
        None,
        description="Individual handwriting token images {hw_id: base64_png} (output_detail: dataset/complete)",
    )
    visual_element_images: Optional[dict] = Field(
        None,
        description="Individual visual element images {ve_id: base64_png} (output_detail: dataset/complete)",
    )
    token_mapping: Optional[dict] = Field(
        None,
        description="Token mapping with positions and style IDs (output_detail: dataset/complete)",
    )

    # OCR.
    ocr_results: Optional[OCRResult] = Field(
        None,
        description="OCR results from final image (if OCR enabled)",
    )

    # Post-processing stages 16-19, export, and cost.
    normalized_bboxes_word: Optional[List[NormalizedBBox]] = Field(
        None,
        description="Word-level normalized bounding boxes (if Stage 16 enabled)",
    )
    normalized_bboxes_segment: Optional[List[NormalizedBBox]] = Field(
        None,
        description="Segment-level normalized bounding boxes (if Stage 16 enabled)",
    )
    gt_verification: Optional[GTVerificationResult] = Field(
        None,
        description="Ground truth verification results (if Stage 17 enabled)",
    )
    analysis_stats: Optional[AnalysisStats] = Field(
        None,
        description="Document analysis statistics (if Stage 18 enabled)",
    )
    debug_visualization: Optional[DebugVisualization] = Field(
        None,
        description="Debug visualization overlays (if Stage 19 enabled)",
    )
    dataset_export: Optional[DatasetExportInfo] = Field(
        None,
        description="Dataset export information (if export enabled)",
    )
    cost_info: Optional[CostInfo] = Field(
        None,
        description="Cost information for this document (Research Parity)",
    )
|
|
|
|
class GenerateDocumentResponse(BaseModel):
    """Envelope returned by the document generation endpoint.

    ``total_documents`` is a separate field with its own default of 0; it is
    not derived from ``documents``.
    """
    success: bool = Field(..., description="Whether generation was successful")
    message: str = Field(..., description="Status message")
    documents: List[DocumentResult] = Field(
        description="List of generated documents",
        default_factory=list,
    )
    total_documents: int = Field(0, description="Total number of documents generated")
    total_cost: Optional[CostInfo] = Field(
        None,
        description="Aggregated cost for the entire request",
    )
|
|
|
|
class HealthResponse(BaseModel):
    """Health check payload; both fields have fixed defaults."""
    status: str = Field("healthy")
    version: str = Field("1.0.0")
|
|