File size: 16,270 Bytes
dc4e6da | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 | """
Pydantic schemas for API request/response models.
"""
from typing import List, Optional
from pydantic import BaseModel, HttpUrl, Field, field_validator
class PromptParameters(BaseModel):
    """Parameters for customizing the document generation prompt."""

    # ------------------------------------------------------------------
    # Prompt content
    # ------------------------------------------------------------------
    language: str = Field(
        "English",
        description="Language for generated documents",
    )
    doc_type: str = Field(
        "business and administrative",
        description="Type of documents to generate (e.g., 'business and administrative', 'receipts', 'forms')",
    )
    gt_type: str = Field(
        "Multiple questions about each document, with their answers taken **verbatim** from the document.",
        description="Description of ground truth type to generate",
    )
    gt_format: str = Field(
        '{"<Text of question 1>": "<Answer to question 1>", "<Text of question 2>": "<Answer to question 2>", ...}',
        description="Format specification for ground truth JSON",
    )
    # Bounded to the inclusive range 1..5.
    num_solutions: int = Field(
        1,
        description="Number of document variations to generate (1-5)",
        ge=1,
        le=5,
    )

    # ------------------------------------------------------------------
    # Stage 3: Feature Synthesis parameters
    # ------------------------------------------------------------------
    enable_handwriting: bool = Field(
        False,
        description="Enable handwriting generation (requires EC2 handwriting service)",
    )
    # Fraction in the inclusive range 0.0..1.0.
    handwriting_ratio: float = Field(
        0.2,
        description="Proportion of text to convert to handwriting (0.0-1.0)",
        ge=0.0,
        le=1.0,
    )
    handwriting_apply_ink_filter: bool = Field(
        True,
        description="Apply high-contrast ink filter to handwriting (v16+ feature)",
    )
    handwriting_enable_enhancements: bool = Field(
        False,
        description="Enable sharpening and contrast boosting (Experimental)",
    )
    # The default sits at the upper bound of the allowed 1..1000 range.
    handwriting_num_inference_steps: int = Field(
        1000,
        description="Number of diffusion inference steps (1-1000)",
        ge=1,
        le=1000,
    )
    handwriting_writer_ids: List[int] = Field(
        [404, 347, 156, 253, 354, 166, 320],
        description="List of writer style IDs to use for handwriting generation",
    )
    enable_visual_elements: bool = Field(
        True,
        description="Enable visual element generation (stamps, logos, barcodes)",
    )
    visual_element_types: List[str] = Field(
        ["stamp", "logo", "figure", "barcode", "photo"],
        description="Types of visual elements to generate (stamp, logo, figure, barcode, photo)",
    )
    # Declared as a free-form optional string; no validator in this class
    # enforces the "numeric only" wording of the description.
    barcode_number: Optional[str] = Field(
        None,
        description="Optional fixed number for barcode generation (numeric only)",
    )
    seed: Optional[int] = Field(
        None,
        description="Random seed for reproducible generation",
        examples=[None, 42],
    )

    # ------------------------------------------------------------------
    # Stage 4: Image Finalization & OCR parameters
    # ------------------------------------------------------------------
    enable_ocr: bool = Field(
        True,
        description="Enable OCR on final document images (requires OCR service)",
    )
    ocr_language: str = Field(
        "en",
        description="Language for OCR (e.g., 'en', 'de', 'fr')",
    )

    # ------------------------------------------------------------------
    # Stage 5: Dataset Packaging parameters
    # ------------------------------------------------------------------
    enable_bbox_normalization: bool = Field(
        True,
        description="Normalize bounding boxes to [0,1] scale (Stage 16)",
    )
    enable_gt_verification: bool = Field(
        True,
        description="Verify and prepare ground truth annotations (Stage 17)",
    )
    enable_analysis: bool = Field(
        True,
        description="Generate dataset statistics and analysis (Stage 18)",
    )
    enable_debug_visualization: bool = Field(
        True,
        description="Create debug visualization overlays (Stage 19)",
    )
    enable_dataset_export: bool = Field(
        True,
        description="Export as msgpack dataset format",
    )
    # Plain strings: the values listed in the descriptions below are not
    # enforced by any constraint declared on these two fields.
    dataset_export_format: str = Field(
        "msgpack",
        description="Dataset export format: 'msgpack', 'coco', 'huggingface'",
    )
    output_detail: str = Field(
        "dataset",
        description="Output detail level: 'minimal' (final outputs only), 'dataset' (includes individual tokens/elements for ML), 'complete' (all intermediate files and debug info). Warning: 'complete' mode can produce 50+ MB responses.",
    )
class SeedImage(BaseModel):
    """Seed image URL for document generation."""

    # Falls back to a fixed sample receipt image when no URL is supplied.
    url: HttpUrl = Field(
        HttpUrl("https://ocr.space/Content/Images/receipt-ocr-original.webp"),
        description="URL of the seed image",
    )
class GenerateDocumentRequest(BaseModel):
    """Request schema for document generation endpoint."""

    # Required: the only field in this model without a default.
    request_id: str = Field(
        description="Document request UUID from document_requests table (created by frontend)"
    )

    # Both Google Drive tokens are optional and default to None.
    google_drive_token: Optional[str] = Field(
        default=None,
        description="Google Drive OAuth access token. Frontend provides this after OAuth flow (optional)."
    )
    google_drive_refresh_token: Optional[str] = Field(
        default=None,
        description="Google Drive refresh token (optional, for automatic token renewal)"
    )

    # Defaults to a single sample receipt image; supplied values are checked
    # by ``validate_seed_images`` below.
    seed_images: List[HttpUrl] = Field(
        default=[HttpUrl("https://ocr.space/Content/Images/receipt-ocr-original.webp")],
        description="List of seed image URLs (1-10 images)"
    )
    prompt_params: PromptParameters = Field(
        default_factory=PromptParameters,
        description="Parameters for customizing the generation prompt"
    )

    @field_validator('seed_images')
    @classmethod
    def validate_seed_images(cls, v: List[HttpUrl]) -> List[HttpUrl]:
        """Ensure between 1 and 10 seed images were supplied.

        Args:
            v: The value provided for ``seed_images``.

        Returns:
            ``v`` unchanged when its length is within bounds.

        Raises:
            ValueError: If ``v`` is empty or holds more than 10 entries.
        """
        # An empty list is falsy, so this one check covers the "fewer than
        # one" case; a separate ``len(v) < 1`` test would be unreachable.
        if not v:
            raise ValueError('At least one seed image is required')
        if len(v) > 10:
            raise ValueError('Maximum 10 seed images allowed')
        return v
class OCRWord(BaseModel):
    """OCR word-level result."""

    # All six fields are required.
    text: str = Field(..., description="Recognized text")
    # Constrained to the inclusive range 0.0..1.0.
    confidence: float = Field(
        ...,
        description="OCR confidence score (0-1)",
        ge=0.0,
        le=1.0,
    )

    # Position and size of the word.
    x: float = Field(..., description="X coordinate (pixels)")
    y: float = Field(..., description="Y coordinate (pixels)")
    width: float = Field(..., description="Width (pixels)")
    height: float = Field(..., description="Height (pixels)")
class OCRLine(BaseModel):
    """OCR line-level result."""

    text: str = Field(..., description="Recognized text")
    # Constrained to the inclusive range 0.0..1.0.
    confidence: float = Field(
        ...,
        description="OCR confidence score (0-1)",
        ge=0.0,
        le=1.0,
    )

    # Position and size of the line.
    x: float = Field(..., description="X coordinate (pixels)")
    y: float = Field(..., description="Y coordinate (pixels)")
    width: float = Field(..., description="Width (pixels)")
    height: float = Field(..., description="Height (pixels)")

    # Starts as an empty list when no words are supplied.
    words: List[OCRWord] = Field(
        description="Words in this line",
        default_factory=list,
    )
class OCRResult(BaseModel):
    """OCR results for a document."""

    # Image dimensions are required.
    image_width: int = Field(..., description="Image width in pixels")
    image_height: int = Field(..., description="Image height in pixels")

    # Both result lists default to empty.
    words: List[OCRWord] = Field(
        description="Word-level OCR results",
        default_factory=list,
    )
    lines: List[OCRLine] = Field(
        description="Line-level OCR results",
        default_factory=list,
    )

    angle: float = Field(
        0.0,
        description="Detected text orientation angle",
    )
class CostInfo(BaseModel):
    """Cost information for a request (Research Parity)."""

    # Required token counts.
    input_tokens: int = Field(..., description="Number of input tokens")
    output_tokens: int = Field(..., description="Number of output tokens")

    # Cache-related counts default to zero.
    cache_creation_tokens: int = Field(
        0,
        description="Tokens used for cache creation",
    )
    cache_read_tokens: int = Field(
        0,
        description="Tokens read from cache",
    )

    cost_usd: float = Field(
        ...,
        description="Total cost in USD (with 50% batch discount applied if applicable)",
    )
    batch_discount_applied: bool = Field(
        False,
        description="Whether 50% batch discount was applied",
    )
class NormalizedBBox(BaseModel):
    """Normalized bounding box (Stage 16)."""

    text: str = Field(..., description="Text content")

    # Each corner coordinate is constrained to the inclusive range 0.0..1.0.
    # The max corner is named x2/y2 (not x1/y1).
    x0: float = Field(
        ...,
        description="Normalized X min (0-1)",
        ge=0.0,
        le=1.0,
    )
    y0: float = Field(
        ...,
        description="Normalized Y min (0-1)",
        ge=0.0,
        le=1.0,
    )
    x2: float = Field(
        ...,
        description="Normalized X max (0-1)",
        ge=0.0,
        le=1.0,
    )
    y2: float = Field(
        ...,
        description="Normalized Y max (0-1)",
        ge=0.0,
        le=1.0,
    )

    # Optional positional indices; each defaults to None.
    block_no: Optional[int] = Field(None, description="Block number")
    line_no: Optional[int] = Field(None, description="Line number")
    word_no: Optional[int] = Field(None, description="Word number")
class GTVerificationResult(BaseModel):
    """Ground truth verification results (Stage 17)."""

    # ``passed`` is the only required field.
    passed: bool = Field(..., description="Whether GT verification passed")
    skipped: bool = Field(
        False,
        description="Whether verification was skipped",
    )

    # Both lists default to empty.
    confirmed_keys: List[str] = Field(
        description="Confirmed GT keys",
        default_factory=list,
    )
    similarities: List[float] = Field(
        description="Similarity scores",
        default_factory=list,
    )

    num_layout_elements: Optional[int] = Field(
        None,
        description="Number of layout elements",
    )
    valid_labels: bool = Field(
        True,
        description="Whether all labels are valid",
    )
class AnalysisStats(BaseModel):
    """Dataset analysis and statistics (Stage 18)."""

    # Required totals.
    total_documents: int = Field(..., description="Total documents processed")
    valid_documents: int = Field(
        ...,
        description="Documents passing all validation",
    )

    # Defaults to an empty dict.
    error_counts: dict = Field(
        description="Error type counts",
        default_factory=dict,
    )

    # Integer counters, each defaulting to zero.
    has_handwriting: int = Field(0, description="Documents with handwriting")
    has_visual_elements: int = Field(
        0,
        description="Documents with visual elements",
    )
    has_ocr: int = Field(0, description="Documents with OCR results")
    multipage_count: int = Field(0, description="Multipage documents")

    token_usage: Optional[dict] = Field(
        None,
        description="LLM token usage statistics",
    )
class DebugVisualization(BaseModel):
    """Debug visualization data (Stage 19)."""

    # Every overlay is an optional string that defaults to None.
    bbox_overlay_base64: Optional[str] = Field(
        None,
        description="Image with bbox overlays (PNG base64)",
    )
    visual_elements_overlay_base64: Optional[str] = Field(
        None,
        description="Image with visual element overlays",
    )
    handwriting_overlay_base64: Optional[str] = Field(
        None,
        description="Image with handwriting overlays",
    )
class DatasetExportInfo(BaseModel):
    """Dataset export metadata."""

    # Required descriptors of the export.
    format: str = Field(..., description="Export format (msgpack, coco, etc.)")
    num_samples: int = Field(..., description="Number of samples in export")

    # Optional payload/location; both default to None.
    output_path: Optional[str] = Field(
        None,
        description="Path to exported dataset",
    )
    msgpack_base64: Optional[str] = Field(
        None,
        description="Msgpack file as base64 (for small datasets)",
    )

    # Defaults to an empty dict.
    metadata: dict = Field(
        description="Dataset metadata",
        default_factory=dict,
    )
class BoundingBox(BaseModel):
    """Bounding box for a text element in the document."""

    text: str = Field(..., description="Text content")

    # The descriptions say "normalized 0-1", but no ge/le constraint is
    # declared on these four floats.
    x: float = Field(..., description="X coordinate (normalized 0-1)")
    y: float = Field(..., description="Y coordinate (normalized 0-1)")
    width: float = Field(..., description="Width (normalized 0-1)")
    height: float = Field(..., description="Height (normalized 0-1)")

    page: int = Field(
        0,
        description="Page number (0-indexed)",
    )
class HandwritingRegion(BaseModel):
    """Information about a handwriting region in the document."""

    # All four fields are required.
    region_id: str = Field(..., description="Unique region identifier")
    text: str = Field(..., description="Text content")
    # Constrained to the inclusive range 0..656.
    author_id: int = Field(
        ...,
        description="Author ID for style consistency (0-656)",
        ge=0,
        le=656,
    )
    bbox: BoundingBox = Field(..., description="Bounding box of the region")
class VisualElement(BaseModel):
    """Information about a visual element in the document."""

    element_id: str = Field(..., description="Unique element identifier")
    element_type: str = Field(
        ...,
        description="Type of visual element (stamp, logo, etc.)",
    )
    # The only optional field; defaults to None.
    content: Optional[str] = Field(
        None,
        description="Content (e.g., stamp text)",
    )
    bbox: BoundingBox = Field(..., description="Bounding box of the element")
class DocumentResult(BaseModel):
    """Result for a single generated document."""

    # ------------------------------------------------------------------
    # Core output
    # ------------------------------------------------------------------
    document_id: str = Field(..., description="Unique document identifier")
    html: str = Field(..., description="Generated HTML content")
    css: str = Field(..., description="Extracted CSS styles")
    ground_truth: Optional[dict] = Field(
        None,
        description="Ground truth data extracted from the document",
    )
    pdf_base64: str = Field(..., description="Base64-encoded PDF document")
    # Defaults to an empty list.
    bboxes: List[BoundingBox] = Field(
        description="Bounding boxes for text elements",
        default_factory=list,
    )
    page_width_mm: float = Field(..., description="Page width in millimeters")
    page_height_mm: float = Field(..., description="Page height in millimeters")

    # ------------------------------------------------------------------
    # Stage 3 additions
    # ------------------------------------------------------------------
    # Both region/element lists are typed as lists of plain dicts.
    handwriting_regions: Optional[List[dict]] = Field(
        None,
        description="Handwriting regions with metadata (if enabled)",
    )
    visual_elements: Optional[List[dict]] = Field(
        None,
        description="Visual elements with metadata (if enabled)",
    )
    image_base64: Optional[str] = Field(
        None,
        description="Final rendered image with handwriting/visuals (PNG base64, if Stage 3 enabled)",
    )

    # Stage 3 individual tokens (dataset/complete output detail levels)
    handwriting_token_images: Optional[dict] = Field(
        None,
        description="Individual handwriting token images {hw_id: base64_png} (output_detail: dataset/complete)",
    )
    visual_element_images: Optional[dict] = Field(
        None,
        description="Individual visual element images {ve_id: base64_png} (output_detail: dataset/complete)",
    )
    token_mapping: Optional[dict] = Field(
        None,
        description="Token mapping with positions and style IDs (output_detail: dataset/complete)",
    )

    # ------------------------------------------------------------------
    # Stage 4 additions
    # ------------------------------------------------------------------
    ocr_results: Optional[OCRResult] = Field(
        None,
        description="OCR results from final image (if OCR enabled)",
    )

    # ------------------------------------------------------------------
    # Stage 5 additions
    # ------------------------------------------------------------------
    normalized_bboxes_word: Optional[List[NormalizedBBox]] = Field(
        None,
        description="Word-level normalized bounding boxes (if Stage 16 enabled)",
    )
    normalized_bboxes_segment: Optional[List[NormalizedBBox]] = Field(
        None,
        description="Segment-level normalized bounding boxes (if Stage 16 enabled)",
    )
    gt_verification: Optional[GTVerificationResult] = Field(
        None,
        description="Ground truth verification results (if Stage 17 enabled)",
    )
    analysis_stats: Optional[AnalysisStats] = Field(
        None,
        description="Document analysis statistics (if Stage 18 enabled)",
    )
    debug_visualization: Optional[DebugVisualization] = Field(
        None,
        description="Debug visualization overlays (if Stage 19 enabled)",
    )
    dataset_export: Optional[DatasetExportInfo] = Field(
        None,
        description="Dataset export information (if export enabled)",
    )
    cost_info: Optional[CostInfo] = Field(
        None,
        description="Cost information for this document (Research Parity)",
    )
class GenerateDocumentResponse(BaseModel):
    """Response schema for document generation endpoint."""

    # Required status pair.
    success: bool = Field(..., description="Whether generation was successful")
    message: str = Field(..., description="Status message")

    # Defaults to an empty list.
    documents: List[DocumentResult] = Field(
        description="List of generated documents",
        default_factory=list,
    )
    # Independent field defaulting to 0; nothing in this class derives it
    # from ``documents``.
    total_documents: int = Field(
        0,
        description="Total number of documents generated",
    )
    total_cost: Optional[CostInfo] = Field(
        None,
        description="Aggregated cost for the entire request",
    )
class HealthResponse(BaseModel):
    """Health check response."""

    # Both fields carry fixed string defaults and no description.
    status: str = Field("healthy")
    version: str = Field("1.0.0")
|