Spaces:
Running
Running
| import base64 | |
| import enum | |
| import re | |
| import urllib.parse | |
| from dataclasses import dataclass | |
| from typing import Optional, Tuple, Dict, Any | |
| from app.core.ocr import OCRProcessor | |
class InputType(enum.Enum):
    """Enumeration of supported input types.

    The value strings are part of the external contract (they may be
    serialized/compared downstream), so they must not be changed.
    """
    TEXT = "text"                  # plain natural-language text
    LATEX = "latex"                # text containing LaTeX math markup
    IMAGE_URL = "image_url"        # http(s) URL pointing at an image
    BASE64_IMAGE = "base64_image"  # base64-encoded image payload / data URI
    MULTIMODAL = "multimodal"      # combined text + image input
    UNKNOWN = "unknown"            # could not classify / empty input
| # ... (omitted dataclass, no changes needed there) ... | |
@dataclass
class ProcessingResult:
    """Result of the input processing pipeline.

    Attributes:
        input_type: Detected type of the input.
        cleaned_content: Normalized text content ("" for image-only inputs).
        is_valid: Whether the input passed validation.
        error_message: Human-readable failure reason when is_valid is False.
        metadata: Optional extras; image payloads travel as
            metadata["image_data"].
    """
    # NOTE: the missing @dataclass decorator is the fix here -- the class
    # body is pure field annotations and callers construct it positionally
    # (e.g. ProcessingResult(InputType.UNKNOWN, "", False, ...)), which
    # requires the generated __init__.
    input_type: InputType
    cleaned_content: str
    is_valid: bool
    error_message: Optional[str] = None
    metadata: Optional[Dict[str, Any]] = None
| class InputProcessor: | |
| """ | |
| Handles detection, normalization, and validation of user inputs. | |
| Attributes: | |
| max_length (int): Maximum allowed characters for text inputs. | |
| """ | |
| def __init__(self, max_length: int = 5000): | |
| """ | |
| Initialize the InputProcessor. | |
| Args: | |
| max_length: Maximum allowed length for input strings. Defaults to 5000. | |
| """ | |
| self.max_length = max_length | |
| # Basic SQL injection and script tag patterns | |
| self._dangerous_patterns = [ | |
| re.compile(r"<script.*?>.*?</script>", re.IGNORECASE | re.DOTALL), | |
| re.compile(r"javascript:", re.IGNORECASE), | |
| re.compile(r"union\s+select", re.IGNORECASE), | |
| re.compile(r"drop\s+table", re.IGNORECASE), | |
| re.compile(r"exec\s*\(", re.IGNORECASE), | |
| ] | |
| self.ocr_processor = OCRProcessor() | |
| def process_compound(self, text_input: Optional[str] = None, image_input: Optional[str] = None) -> ProcessingResult: | |
| """ | |
| Process combined text and image input. | |
| Args: | |
| text_input: Optional text query. | |
| image_input: Optional image (Base64 or URL). | |
| Returns: | |
| ProcessingResult: Combined result. | |
| """ | |
| cleaned_text = "" | |
| image_data = None | |
| detected_type = InputType.UNKNOWN | |
| error_msg = None | |
| # 1. Process Image if present | |
| if image_input: | |
| # Detect if URL or Base64 (naive check) | |
| if image_input.startswith("http") and "://" in image_input: | |
| # URL -> Download | |
| image_data = self.ocr_processor.download_image_as_base64(image_input) | |
| if not image_data: | |
| return ProcessingResult(InputType.IMAGE_URL, "", False, "Failed to download image.") | |
| detected_type = InputType.IMAGE_URL # Or promote to MULTIMODAL later | |
| else: | |
| # Assume Base64 | |
| # Strip prefix if needed | |
| if ";base64," in image_input: | |
| _, raw_b64 = image_input.split(";base64,") | |
| else: | |
| raw_b64 = image_input.strip() | |
| # Basic validation? | |
| if len(raw_b64) < 10: | |
| return ProcessingResult(InputType.BASE64_IMAGE, "", False, "Invalid image data.") | |
| image_data = raw_b64 | |
| detected_type = InputType.BASE64_IMAGE | |
| # 2. Process Text if present | |
| if text_input: | |
| cleaned_text = self._normalize_text(text_input) | |
| # If we also have an image, it's MULTIMODAL. | |
| # CRITICAL: We MUST preserve the text input as it provides specific context (e.g. "Solve part b") | |
| if image_data: | |
| detected_type = InputType.MULTIMODAL | |
| elif detected_type == InputType.UNKNOWN: | |
| # Text only, refined detection (latex vs text) | |
| detected_type = self._detect_type(cleaned_text) | |
| # 3. Final Validation | |
| if not cleaned_text and not image_data: | |
| return ProcessingResult(InputType.UNKNOWN, "", False, "No valid input provided.") | |
| metadata = {} | |
| if image_data: | |
| metadata["image_data"] = image_data | |
| # Validate text content if present (length, safety) | |
| if cleaned_text: | |
| is_valid, err = self._validate(cleaned_text, detected_type) | |
| if not is_valid: | |
| return ProcessingResult(detected_type, cleaned_text, False, err) | |
| return ProcessingResult( | |
| input_type=detected_type, | |
| cleaned_content=cleaned_text, | |
| is_valid=True, | |
| metadata=metadata | |
| ) | |
| def process(self, input_data: str) -> ProcessingResult: | |
| """ | |
| Process the raw input string: detect type, normalize, and validate. | |
| Args: | |
| input_data: The raw input string from the user. | |
| Returns: | |
| ProcessingResult: The processed and validated result. | |
| """ | |
| if not input_data: | |
| return ProcessingResult(InputType.UNKNOWN, "", False, "Input cannot be empty.") | |
| metadata = None | |
| detected_type = self._detect_type(input_data) | |
| if detected_type in (InputType.TEXT, InputType.LATEX): | |
| cleaned_content = self._normalize_text(input_data) | |
| elif detected_type == InputType.BASE64_IMAGE: | |
| # Process Base64 | |
| # Store raw base64 for Vision Model (strip prefix) | |
| try: | |
| if ";base64," in input_data: | |
| _, raw_b64 = input_data.split(";base64,") | |
| else: | |
| raw_b64 = input_data | |
| except ValueError: | |
| return ProcessingResult(detected_type, "", False, "Invalid base64 image format.") | |
| # Skip OCR for Base64 images to avoid "double work" and latency. | |
| # We rely on Gemini Vision to read the image directly. | |
| # cleaned_content is set to empty string; hashing will rely on image_data hash. | |
| extracted_text = "" | |
| # Attach image data to result | |
| # Optimize image (resize/compress) to reduce token count and bandwidth | |
| optimized_b64 = self.ocr_processor.optimize_base64(raw_b64) | |
| metadata = {"image_data": optimized_b64} | |
| cleaned_content = "" | |
| elif detected_type == InputType.IMAGE_URL: | |
| # Process URL: Download and pass as image to Vision (Skip OCR) | |
| # extracted_text = self.ocr_processor.process_url(input_data) | |
| raw_b64 = self.ocr_processor.download_image_as_base64(input_data) | |
| if not raw_b64: | |
| return ProcessingResult(detected_type, "", False, "Failed to download image from URL.") | |
| # Attach image data (Vision will use this) | |
| metadata = {"image_data": raw_b64} | |
| cleaned_content = "" | |
| else: | |
| cleaned_content = input_data.strip() | |
| metadata = None | |
| is_valid, error_msg = self._validate(cleaned_content, detected_type) | |
| return ProcessingResult( | |
| input_type=detected_type, | |
| cleaned_content=cleaned_content, | |
| is_valid=is_valid, | |
| error_message=error_msg, | |
| metadata=metadata | |
| ) | |
| def _detect_type(self, data: str) -> InputType: | |
| """ | |
| Detect the type of the input data. | |
| Args: | |
| data: The raw input string. | |
| Returns: | |
| InputType: The detected input type. | |
| """ | |
| data = data.strip() | |
| # Check for Base64 Image | |
| # A simple heuristic: starts with data:image/ or looks like base64 | |
| if data.startswith("data:image/") and ";base64," in data: | |
| return InputType.BASE64_IMAGE | |
| # Check for Image URL | |
| parsed_url = urllib.parse.urlparse(data) | |
| if parsed_url.scheme in ("http", "https") and any( | |
| data.lower().endswith(ext) for ext in [".jpg", ".jpeg", ".png", ".gif", ".webp"] | |
| ): | |
| return InputType.IMAGE_URL | |
| # Check for LaTeX | |
| # Heuristic: contains mathematical delimiters or keywords | |
| if ( | |
| "$" in data | |
| or "\\[" in data | |
| or "\\(" in data | |
| or re.search(r"\\[a-zA-Z]+", data) | |
| ): | |
| return InputType.LATEX | |
| # Default to text | |
| return InputType.TEXT | |
| def _normalize_text(self, text: str) -> str: | |
| """ | |
| Normalize text input: lowercase, trim, remove extra spaces. | |
| Args: | |
| text: The text to normalize. | |
| Returns: | |
| str: Normalized text. | |
| """ | |
| # Lowercase | |
| text = text.lower() | |
| # Remove extra horizontal whitespace (tabs, multiple spaces) | |
| text = re.sub(r'[ \t]+', ' ', text) | |
| # Collapse multiple newlines into one | |
| text = re.sub(r'\n+', '\n', text) | |
| return text.strip() | |
| def _remove_ocr_artifacts(self, text: str) -> str: | |
| """Remove common OCR extraction errors.""" | |
| # Remove repeated characters (OCR artifact) | |
| text = re.sub(r'([!?.\-_=])\1{3,}', r'\1\1', text) | |
| # Remove random special chars at start/end | |
| # Added '-' to the allowed list to prevent stripping negative numbers | |
| text = re.sub(r'^[^a-zA-Z0-9\(\[\{$\-]*', '', text) | |
| text = re.sub(r'[^a-zA-Z0-9\)\]\}\$]*$', '', text) | |
| return text | |
| def _validate(self, content: str, input_type: InputType) -> Tuple[bool, Optional[str]]: | |
| """ | |
| Validate the content based on its type and generic safety rules. | |
| Args: | |
| content: The content to validate. | |
| input_type: The type of the content. | |
| Returns: | |
| Tuple[bool, Optional[str]]: (IsValid, ErrorMessage) | |
| """ | |
| if len(content) > self.max_length and input_type != InputType.BASE64_IMAGE: | |
| return False, f"Input length exceeds maximum limit of {self.max_length} characters." | |
| # For Base64, we allow larger size, but maybe still impose a limit? | |
| # For now, let's assume max_length applies to text/latex/url | |
| if input_type == InputType.BASE64_IMAGE: | |
| # Heuristic check for base64 validity | |
| if len(content) > 5_000_000: #5MB limit catch | |
| return False, "Image data too large." | |
| # Further base64 validation could be done here if needed | |
| return True, None | |
| if input_type == InputType.IMAGE_URL: | |
| # Basic URL validation done in detect and OCR, but check again for safety | |
| parsed = urllib.parse.urlparse(content) | |
| # Note: For IMAGE_URL, 'content' here is technically the ACTUALLY EXTRACTED TEXT now if we look at flow above? | |
| # Wait, logic check: | |
| # In `process`: | |
| # 1. detect_type -> IMAGE_URL | |
| # 2. if IMAGE_URL -> ocr -> extracted_text | |
| # 3. cleaned_content = extracted_text | |
| # 4. _validate(cleaned_content, IMAGE_URL) | |
| # So content passed to validate is TEXT. | |
| # BUT `detect_type` is IMAGE_URL. | |
| # So `_validate` logic for IMAGE_URL checking scheme is invalid b/c content is now text. | |
| # We should rely on OCR processor to have validated the image/url itself. | |
| # So here for IMAGE_URL/BASE64_IMAGE, we are validating the EXTRACTED text. | |
| pass | |
| # if parsed.scheme not in ('http', 'https'): | |
| # return False, "Invalid URL scheme." | |
| # return True, None | |
| # Check for dangerous payloads in text/latex | |
| for pattern in self._dangerous_patterns: | |
| if pattern.search(content): | |
| return False, "Input contains potentially dangerous content." | |
| return True, None | |