"""
Vision-Language Model Interface

Abstract interface for multimodal models that understand both
images and text. Used for document understanding, VQA, and
complex reasoning over visual content.
"""
| |
|
| | from abc import abstractmethod |
| | from dataclasses import dataclass, field |
| | from enum import Enum |
| | from typing import Any, Dict, List, Optional, Tuple, Union |
| |
|
| | from ..chunks.models import BoundingBox |
| | from .base import ( |
| | BaseModel, |
| | BatchableModel, |
| | ImageInput, |
| | ModelCapability, |
| | ModelConfig, |
| | ) |
| |
|
| |
|
class VLMTask(str, Enum):
    """Tasks that VLM models can perform.

    Inherits from ``str`` so members compare equal to their string
    values and serialize naturally (e.g. in JSON payloads).
    """

    # Document-level tasks.
    DOCUMENT_QA = "document_qa"
    DOCUMENT_SUMMARY = "document_summary"
    DOCUMENT_CLASSIFICATION = "document_classification"

    # General image tasks.
    IMAGE_CAPTION = "image_caption"
    IMAGE_QA = "image_qa"
    VISUAL_GROUNDING = "visual_grounding"

    # Structured-content tasks.
    FIELD_EXTRACTION = "field_extraction"
    TABLE_UNDERSTANDING = "table_understanding"
    CHART_UNDERSTANDING = "chart_understanding"

    # Text-centric tasks.
    OCR_CORRECTION = "ocr_correction"
    TEXT_GENERATION = "text_generation"

    # Fallback when no dedicated task applies.
    GENERAL = "general"
| |
|
| |
|
@dataclass
class VLMConfig(ModelConfig):
    """Configuration for vision-language models.

    Extends the base ``ModelConfig`` with generation and image
    preprocessing knobs used by VLM backends.
    """

    max_tokens: int = 2048
    temperature: float = 0.1
    top_p: float = 0.9
    max_image_size: int = 1024
    image_detail: str = "high"
    system_prompt: Optional[str] = None

    def __post_init__(self):
        super().__post_init__()
        # Fall back to a generic model name when none was supplied.
        self.name = self.name or "vlm"
| |
|
| |
|
@dataclass
class VLMMessage:
    """A message in a VLM conversation."""

    # Speaker role (e.g. "system"/"user"/"assistant"); values are not
    # validated here — the concrete backend interprets them.
    role: str
    # Text content of this turn.
    content: str
    # Images attached to this turn; empty for text-only messages.
    images: List[ImageInput] = field(default_factory=list)
    # Optional region-of-interest per image; entries may be None.
    # NOTE(review): presumably parallel to `images` — confirm with callers.
    image_regions: List[Optional[BoundingBox]] = field(default_factory=list)
| |
|
| |
|
@dataclass
class VLMResponse:
    """Response from a VLM model."""

    # Raw generated text.
    text: str
    # Model-reported confidence in [0, 1]; 0.0 when not provided.
    confidence: float = 0.0
    # Token count consumed by the request, if the backend reports it.
    tokens_used: int = 0
    # Why generation ended (e.g. "stop", "length").
    finish_reason: str = "stop"

    # Visual grounding: regions the answer refers to, with parallel labels.
    grounded_regions: List[BoundingBox] = field(default_factory=list)
    region_labels: List[str] = field(default_factory=list)

    # Optional structured payload (e.g. parsed JSON) when the task yields one.
    structured_data: Optional[Dict[str, Any]] = None

    # Timing and backend-specific extras.
    processing_time_ms: float = 0.0
    model_metadata: Dict[str, Any] = field(default_factory=dict)
| |
|
| |
|
@dataclass
class DocumentQAResult:
    """Result of document question answering."""

    # The question that was asked.
    question: str
    # The model's answer text (may be empty if none was parsed).
    answer: str
    # Answer confidence in [0, 1]; 0.0 when unknown.
    confidence: float = 0.0

    # Supporting evidence: regions on the page, quoted text snippets,
    # and page indices the answer was drawn from.
    evidence_regions: List[BoundingBox] = field(default_factory=list)
    evidence_text: List[str] = field(default_factory=list)
    page_references: List[int] = field(default_factory=list)

    # Abstention: True when the model declined to answer, with the reason.
    abstained: bool = False
    abstention_reason: Optional[str] = None
| |
|
| |
|
@dataclass
class FieldExtractionVLMResult:
    """Result of field extraction using VLM."""

    # Extracted field values keyed by field name, with per-field
    # confidence scores in [0, 1].
    fields: Dict[str, Any] = field(default_factory=dict)
    confidence_scores: Dict[str, float] = field(default_factory=dict)

    # Provenance: where each field was found on the page, and the text
    # snippet it was read from.
    field_regions: Dict[str, BoundingBox] = field(default_factory=dict)
    field_evidence: Dict[str, str] = field(default_factory=dict)

    # Fields the model could not extract, with per-field reasons.
    abstained_fields: List[str] = field(default_factory=list)
    abstention_reasons: Dict[str, str] = field(default_factory=dict)

    # Mean of confidence_scores; 0.0 when no scores were reported.
    overall_confidence: float = 0.0
| |
|
| |
|
class VisionLanguageModel(BatchableModel):
    """
    Abstract base class for Vision-Language Models.

    These models combine visual understanding with language
    capabilities for tasks like document QA, field extraction,
    and visual reasoning.

    Subclasses must implement :meth:`generate` and :meth:`chat`.
    The task-specific helpers (:meth:`answer_question`,
    :meth:`extract_fields`, :meth:`summarize_document`,
    :meth:`classify_document`) are thin prompt-building /
    response-parsing layers over :meth:`generate`.
    """

    def __init__(self, config: Optional[VLMConfig] = None):
        """
        Args:
            config: VLM configuration; a default VLMConfig named "vlm"
                is used when omitted.
        """
        super().__init__(config or VLMConfig(name="vlm"))
        # Re-annotation only: the base class stores the config we just
        # passed in; this narrows the attribute type for checkers.
        self.config: VLMConfig = self.config

    def get_capabilities(self) -> List[ModelCapability]:
        """Return the capabilities this model provides."""
        return [ModelCapability.VISION_LANGUAGE]

    @abstractmethod
    def generate(
        self,
        prompt: str,
        images: List[ImageInput],
        **kwargs
    ) -> VLMResponse:
        """
        Generate a response given text prompt and images.

        Args:
            prompt: Text prompt/question
            images: List of images for context
            **kwargs: Additional generation parameters

        Returns:
            VLMResponse with generated text
        """
        pass

    def process_batch(
        self,
        inputs: List[Tuple[str, List[ImageInput]]],
        **kwargs
    ) -> List[VLMResponse]:
        """
        Process multiple prompt-image pairs.

        Default implementation calls :meth:`generate` sequentially;
        backends with true batching should override this.

        Args:
            inputs: List of (prompt, images) tuples
            **kwargs: Additional parameters

        Returns:
            List of VLMResponses, one per input pair (in order)
        """
        return [
            self.generate(prompt, images, **kwargs)
            for prompt, images in inputs
        ]

    @abstractmethod
    def chat(
        self,
        messages: List[VLMMessage],
        **kwargs
    ) -> VLMResponse:
        """
        Multi-turn conversation with images.

        Args:
            messages: Conversation history
            **kwargs: Additional parameters

        Returns:
            VLMResponse for the conversation
        """
        pass

    def answer_question(
        self,
        question: str,
        document_images: List[ImageInput],
        context: Optional[str] = None,
        **kwargs
    ) -> DocumentQAResult:
        """
        Answer a question about document images.

        Args:
            question: Question to answer
            document_images: Document page images
            context: Optional additional context
            **kwargs: Additional parameters

        Returns:
            DocumentQAResult with answer and evidence
        """
        prompt = self._build_qa_prompt(question, context)
        response = self.generate(prompt, document_images, **kwargs)

        answer, confidence, abstained, reason = self._parse_qa_response(response.text)
        # The QA prompt asks the model for an EVIDENCE: line; previously
        # it was requested but discarded — capture it as evidence_text.
        evidence_text = self._parse_evidence_lines(response.text)

        return DocumentQAResult(
            question=question,
            answer=answer,
            confidence=confidence,
            evidence_regions=response.grounded_regions,
            evidence_text=evidence_text,
            abstained=abstained,
            abstention_reason=reason
        )

    def extract_fields(
        self,
        images: List[ImageInput],
        schema: Dict[str, Any],
        **kwargs
    ) -> FieldExtractionVLMResult:
        """
        Extract fields from document images according to a schema.

        Args:
            images: Document page images
            schema: Field schema (JSON Schema or Pydantic-like)
            **kwargs: Additional parameters

        Returns:
            FieldExtractionVLMResult with extracted values
        """
        prompt = self._build_extraction_prompt(schema)
        response = self.generate(prompt, images, **kwargs)

        result = self._parse_extraction_response(response, schema)
        return result

    def summarize_document(
        self,
        images: List[ImageInput],
        max_length: int = 500,
        **kwargs
    ) -> str:
        """
        Generate a summary of document images.

        Args:
            images: Document page images
            max_length: Maximum summary length in characters (a hint in
                the prompt; the model output is not truncated here)
            **kwargs: Additional parameters

        Returns:
            Document summary text
        """
        prompt = f"""Summarize this document in at most {max_length} characters.
Focus on the main points and key information.
Be concise and factual."""

        response = self.generate(prompt, images, **kwargs)
        return response.text

    def classify_document(
        self,
        images: List[ImageInput],
        categories: List[str],
        **kwargs
    ) -> Tuple[str, float]:
        """
        Classify document into predefined categories.

        Args:
            images: Document page images
            categories: List of possible categories
            **kwargs: Additional parameters

        Returns:
            Tuple of (category, confidence); ("UNKNOWN", 0.0) when the
            model's answer is malformed or not one of the categories.
        """
        categories_str = ", ".join(categories)
        prompt = f"""Classify this document into one of these categories: {categories_str}

Respond with just the category name and confidence (0-1).
Format: CATEGORY: confidence

If you cannot confidently classify, respond with: UNKNOWN: 0.0"""

        response = self.generate(prompt, images, **kwargs)

        try:
            parts = response.text.strip().split(":")
            category = parts[0].strip().upper()
            confidence = float(parts[1].strip()) if len(parts) > 1 else 0.5
            # Models occasionally report values outside [0, 1]; clamp.
            confidence = max(0.0, min(1.0, confidence))

            # Case-insensitive match, but return the caller's original casing.
            category_upper = {c.upper(): c for c in categories}
            if category in category_upper:
                return category_upper[category], confidence
            return "UNKNOWN", 0.0
        except (ValueError, IndexError):
            # Malformed response (e.g. non-numeric confidence).
            return "UNKNOWN", 0.0

    def _build_qa_prompt(
        self,
        question: str,
        context: Optional[str] = None
    ) -> str:
        """Build prompt for document QA."""
        prompt_parts = [
            "You are analyzing a document image. Answer the following question based only on what you can see in the document.",
            "",
            "IMPORTANT RULES:",
            "- Only use information visible in the document",
            "- If the answer is not found, say 'NOT FOUND' and explain why",
            "- Be precise and quote exact values when possible",
            "- Indicate your confidence level (HIGH, MEDIUM, LOW)",
            ""
        ]

        if context:
            prompt_parts.extend([
                "Additional context:",
                context,
                ""
            ])

        prompt_parts.extend([
            f"Question: {question}",
            "",
            "Provide your answer in this format:",
            "ANSWER: [your answer]",
            "CONFIDENCE: [HIGH/MEDIUM/LOW]",
            "EVIDENCE: [quote or describe where you found this information]"
        ])

        return "\n".join(prompt_parts)

    def _parse_qa_response(
        self,
        response_text: str
    ) -> Tuple[str, float, bool, Optional[str]]:
        """Parse QA response for answer, confidence, and abstention.

        Returns:
            Tuple of (answer, confidence, abstained, abstention_reason).
            Confidence maps HIGH/MEDIUM/LOW to 0.9/0.6/0.3, defaulting
            to 0.5 when absent or unrecognized.
        """
        lines = response_text.strip().split("\n")

        answer = ""
        confidence = 0.5
        abstained = False
        reason = None

        for line in lines:
            line_lower = line.lower()
            if line_lower.startswith("answer:"):
                answer = line.split(":", 1)[1].strip()
            elif line_lower.startswith("confidence:"):
                conf_str = line.split(":", 1)[1].strip().upper()
                confidence = {"HIGH": 0.9, "MEDIUM": 0.6, "LOW": 0.3}.get(conf_str, 0.5)

        # The prompt instructs the model to say 'NOT FOUND' when it
        # cannot answer — treat that (and similar phrasing) as abstention.
        if "not found" in answer.lower() or "cannot find" in answer.lower():
            abstained = True
            reason = answer

        return answer, confidence, abstained, reason

    def _parse_evidence_lines(self, response_text: str) -> List[str]:
        """Collect the EVIDENCE: lines the QA prompt asks the model to emit."""
        evidence: List[str] = []
        for line in response_text.strip().split("\n"):
            if line.lower().startswith("evidence:"):
                snippet = line.split(":", 1)[1].strip()
                if snippet:
                    evidence.append(snippet)
        return evidence

    def _build_extraction_prompt(self, schema: Dict[str, Any]) -> str:
        """Build prompt for field extraction."""
        import json

        schema_str = json.dumps(schema, indent=2)

        prompt = f"""Extract the following fields from this document image.

SCHEMA:
{schema_str}

RULES:
- Only extract values that are clearly visible in the document
- For each field, provide the exact value and its location
- If a field is not found, mark it as null with confidence 0
- Be precise with numbers, dates, and proper nouns

Respond in valid JSON format matching the schema.
Include a "_confidence" object with confidence scores (0-1) for each field.
Include a "_evidence" object with the text snippet where each value was found.
"""
        return prompt

    def _parse_extraction_response(
        self,
        response: VLMResponse,
        schema: Dict[str, Any]
    ) -> FieldExtractionVLMResult:
        """Parse extraction response into structured result.

        Strips an optional Markdown code fence, parses the JSON object,
        separates "_"-prefixed metadata keys, and marks missing/null
        schema fields as abstentions. On unparseable output, every
        schema field is abstained.
        """
        import json

        result = FieldExtractionVLMResult()

        try:
            text = response.text.strip()

            # Strip a Markdown code fence if the model wrapped its JSON.
            # When the closing fence is missing, find() returns -1 and a
            # naive slice would silently drop the final character — fall
            # back to the end of the text instead.
            if "```json" in text:
                start = text.find("```json") + 7
                end = text.find("```", start)
                if end == -1:
                    end = len(text)
                text = text[start:end].strip()
            elif "```" in text:
                start = text.find("```") + 3
                end = text.find("```", start)
                if end == -1:
                    end = len(text)
                text = text[start:end].strip()

            data = json.loads(text)
            # The prompt requests a JSON object; a list/scalar cannot be
            # mapped onto fields (and would crash on .items() below).
            if not isinstance(data, dict):
                raise json.JSONDecodeError("expected a JSON object", text, 0)

            # Copy extracted values, skipping "_"-prefixed metadata keys.
            for key, value in data.items():
                if key.startswith("_"):
                    continue
                result.fields[key] = value

            if "_confidence" in data:
                result.confidence_scores = data["_confidence"]

            if "_evidence" in data:
                result.field_evidence = data["_evidence"]

            # Any schema field that is missing or null counts as abstained.
            for field_name in schema.get("properties", {}).keys():
                if field_name not in result.fields or result.fields[field_name] is None:
                    result.abstained_fields.append(field_name)
                    result.abstention_reasons[field_name] = "Field not found in document"

            # Overall confidence: mean of the reported per-field scores.
            if result.confidence_scores:
                result.overall_confidence = sum(result.confidence_scores.values()) / len(result.confidence_scores)

        except json.JSONDecodeError:
            # Unparseable output: abstain on every requested field.
            for field_name in schema.get("properties", {}).keys():
                result.abstained_fields.append(field_name)
                result.abstention_reasons[field_name] = "Failed to parse extraction response"

        return result
| |
|