| import json |
| import re |
| import hashlib |
| import os |
| from typing import Dict, Any, List, Optional, Tuple, Union |
| from dataclasses import dataclass, field |
| import asyncio |
| import logging |
| from datetime import datetime |
| import openai |
| from openai import AsyncOpenAI |
|
|
# Configure root logging at import time, then create the module-level logger
# that the classes in this file write to.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
|
|
@dataclass
class ComplexityMetrics:
    """Counts gathered from a schema, plus a coarse tier derived from them."""

    max_depth: int  # deepest nesting level reached while walking the schema
    total_fields: int  # number of property entries counted
    enum_count: int  # property definitions that carry an 'enum' key
    required_fields: int  # length of the top-level 'required' list
    nested_objects: int  # property definitions whose type is 'object'

    @property
    def complexity_tier(self) -> int:
        """Return 1, 2 or 3 based on ``max_depth`` and ``total_fields``.

        Tier 1: depth <= 2 and at most 20 fields.
        Tier 2: depth <= 4 and at most 100 fields.
        Tier 3: everything else.
        """
        depth, fields = self.max_depth, self.total_fields
        if depth <= 2 and fields <= 20:
            return 1
        if depth <= 4 and fields <= 100:
            return 2
        return 3
|
|
@dataclass
class ExtractionStage:
    """One step of an extraction plan: a named group of fields to extract."""

    # Stage identifier; also the key used to look up the stage's model.
    name: str
    # Top-level property names this stage is responsible for.
    fields: List[str]
    # Schema handed to the model for this stage.
    schema_subset: Dict[str, Any]
    # Relative difficulty of the stage (plans in this file use 1-3).
    complexity: int
    # Names of stages that should run before this one.
    dependencies: List[str] = field(default_factory=list)
    # Rough token budget for the stage.
    estimated_tokens: int = 0
|
|
@dataclass
class ExtractionPlan:
    """Ordered stages plus the cost, time and model choices that go with them."""

    # Stages in the order they are executed.
    stages: List[ExtractionStage]
    # Estimated cost of running the plan (currency/unit not stated in this file).
    estimated_cost: float
    # Estimated duration of the plan (unit not stated in this file).
    estimated_time: float
    # Maps a stage name to the model name that should run it.
    model_assignments: Dict[str, str]
    # Stage names that could run concurrently; nothing in this file fills it.
    parallelizable_stages: List[str] = field(default_factory=list)
|
|
@dataclass
class ExtractionResult:
    """Merged output of an extraction run."""

    # Extracted values keyed by field name.
    data: Dict[str, Any]
    # Per-field confidence keyed by field name.
    confidence_scores: Dict[str, float]
    # One summary dict per stage that was executed.
    stage_results: List[Dict[str, Any]] = field(default_factory=list)
    # Free-form run information (stage count, estimated cost, timing).
    metadata: Dict[str, Any] = field(default_factory=dict)
    # Elapsed time of the run as measured by the engine.
    processing_time: float = 0.0
|
|
@dataclass
class QualityReport:
    """Scores and review hints produced when assessing an extraction result."""

    # Combined score for the whole extraction.
    overall_confidence: float
    # Per-field confidence keyed by field name.
    field_scores: Dict[str, float]
    # Short string labels describing problems found.
    review_flags: List[str]
    # Score for required-field presence and type agreement with the schema.
    schema_compliance: float
    # Score from the internal consistency checks.
    consistency_score: float
    # Suggested manual review effort (unit not stated in this file).
    recommended_review_time: int = 0
|
|
class OpenAIClient:
    """Async wrapper around a single OpenAI chat model used for extraction."""

    def __init__(self, model_name: str, api_key: str):
        """Create the API client for ``model_name`` using ``api_key``."""
        self.model_name = model_name
        self.client = AsyncOpenAI(api_key=api_key)
        # Price table keyed by model name. Nothing in this class reads it.
        # NOTE(review): the unit of these figures is not stated here - confirm.
        self.cost_per_token = {
            "gpt-4o-mini": 0.00015,
            "gpt-4o": 0.005,
            "gpt-4-turbo": 0.003
        }

    async def complete(self, prompt: str, max_tokens: int = 4000) -> Tuple[str, float]:
        """Send ``prompt`` to the model and return ``(text, confidence)``.

        Args:
            prompt: User message sent after the fixed system message.
            max_tokens: Completion token limit passed to the API.

        Returns:
            The response text and a heuristic confidence. On any failure the
            text is a JSON object with ``error`` and ``details`` keys and the
            confidence is 0.1; this method does not raise.
        """
        try:
            response = await self.client.chat.completions.create(
                model=self.model_name,
                messages=[
                    {"role": "system", "content": "You are a precise data extraction specialist. Extract data according to the provided schema and output only valid JSON."},
                    {"role": "user", "content": prompt}
                ],
                max_tokens=max_tokens,
                temperature=0.1,
                top_p=0.9
            )

            # The API may return no text at all; normalise that to an empty
            # string so the declared return type holds for callers.
            content = response.choices[0].message.content or ""

            # Heuristic base score; any model name containing "gpt-4o"
            # (including "gpt-4o-mini") gets the higher value.
            confidence = 0.9 if "gpt-4o" in self.model_name else 0.8

            if len(content.strip()) > 10:
                confidence += 0.05

            return content, confidence

        except Exception as e:
            logger.error(f"OpenAI API error: {e}")
            # json.dumps escapes quotes, backslashes and newlines in the
            # exception text, so the payload is always valid JSON.
            return json.dumps({"error": "API call failed", "details": str(e)}), 0.1
|
|
class SchemaAnalyzer:
    """Measure JSON-schema complexity and build extraction plans from a schema."""

    def analyze_complexity(self, schema: Dict[str, Any]) -> ComplexityMetrics:
        """Collect depth, field, enum and object counts for ``schema``.

        Args:
            schema: JSON-schema-like dict. ``properties`` and ``required``
                are read from its top level.

        Returns:
            A ``ComplexityMetrics`` holding the counts.
        """

        def count_depth(obj: Any, current_depth: int = 0) -> int:
            # Each nested dict adds one level. A dict that has a 'properties'
            # key is entered through that key, so the wrapper itself does not
            # add a second level.
            if not isinstance(obj, dict):
                return current_depth

            max_child_depth = current_depth
            for value in obj.values():
                if isinstance(value, dict):
                    if 'properties' in value:
                        child_depth = count_depth(value['properties'], current_depth + 1)
                    else:
                        child_depth = count_depth(value, current_depth + 1)
                    max_child_depth = max(max_child_depth, child_depth)
            return max_child_depth

        def count_fields(obj: Any) -> Tuple[int, int, int]:
            # Returns (fields, enum fields, object fields) found under obj,
            # descending into every nested dict.
            if not isinstance(obj, dict):
                return 0, 0, 0

            total, enums, objects = 0, 0, 0

            for key, value in obj.items():
                if key == 'properties' and isinstance(value, dict):
                    for prop_def in value.values():
                        total += 1
                        if isinstance(prop_def, dict):
                            if 'enum' in prop_def:
                                enums += 1
                            if prop_def.get('type') == 'object':
                                objects += 1
                            nested_total, nested_enums, nested_objects = count_fields(prop_def)
                            total += nested_total
                            enums += nested_enums
                            objects += nested_objects
                elif isinstance(value, dict):
                    nested_total, nested_enums, nested_objects = count_fields(value)
                    total += nested_total
                    enums += nested_enums
                    objects += nested_objects

            return total, enums, objects

        max_depth = count_depth(schema.get('properties', {}))
        total_fields, enum_count, nested_objects = count_fields(schema)
        required_fields = len(schema.get('required', []))

        return ComplexityMetrics(
            max_depth=max_depth,
            total_fields=total_fields,
            enum_count=enum_count,
            required_fields=required_fields,
            nested_objects=nested_objects
        )

    def create_extraction_plan(self, schema: Dict[str, Any], complexity: ComplexityMetrics) -> ExtractionPlan:
        """Return the plan used to extract ``schema``.

        The single-pass plan is always returned; ``complexity`` is accepted
        but not consulted.
        NOTE(review): the simple/medium/complex builders below are not
        reachable from here - confirm whether tier-based selection is wanted.
        """
        return self._create_single_pass_plan(schema)

    def _create_single_pass_plan(self, schema: Dict[str, Any]) -> ExtractionPlan:
        """Build a one-stage plan that extracts every top-level property."""
        stages = [ExtractionStage(
            name="complete_extraction",
            fields=list(schema.get('properties', {}).keys()),
            schema_subset=schema,
            complexity=2,
            estimated_tokens=4000
        )]

        return ExtractionPlan(
            stages=stages,
            estimated_cost=0.15,
            estimated_time=15.0,
            model_assignments={"complete_extraction": "gpt-4o"}
        )

    def _create_simple_plan(self, schema: Dict[str, Any]) -> ExtractionPlan:
        """Build a one-stage plan with the lower cost and time estimates."""
        stages = [ExtractionStage(
            name="complete_extraction",
            fields=list(schema.get('properties', {}).keys()),
            schema_subset=schema,
            complexity=1,
            estimated_tokens=2000
        )]

        return ExtractionPlan(
            stages=stages,
            estimated_cost=0.02,
            estimated_time=5.0,
            model_assignments={"complete_extraction": "gpt-4o"}
        )

    def _create_medium_plan(self, schema: Dict[str, Any]) -> ExtractionPlan:
        """Build a plan with one stage for scalar fields and one for object/array fields."""
        properties = schema.get('properties', {})
        simple_fields = []
        complex_fields = []

        for field_name, field_def in properties.items():
            if isinstance(field_def, dict) and field_def.get('type') in ['object', 'array']:
                complex_fields.append(field_name)
            else:
                simple_fields.append(field_name)

        stages = []
        if simple_fields:
            stages.append(ExtractionStage(
                name="simple_fields",
                fields=simple_fields,
                schema_subset=self._create_subset_schema(schema, simple_fields),
                complexity=1,
                estimated_tokens=1500
            ))

        if complex_fields:
            stages.append(ExtractionStage(
                name="complex_fields",
                fields=complex_fields,
                schema_subset=self._create_subset_schema(schema, complex_fields),
                complexity=2,
                dependencies=["simple_fields"] if simple_fields else [],
                estimated_tokens=3000
            ))

        return ExtractionPlan(
            stages=stages,
            estimated_cost=0.15,
            estimated_time=25.0,
            model_assignments={
                "simple_fields": "gpt-4o-mini",
                "complex_fields": "gpt-4o"
            }
        )

    def _create_complex_plan(self, schema: Dict[str, Any]) -> ExtractionPlan:
        """Build a multi-stage plan, giving stages of complexity > 1 the larger model."""
        stages = self._create_hierarchical_stages(schema)

        model_assignments = {
            stage.name: "gpt-4o" if stage.complexity > 1 else "gpt-4o-mini"
            for stage in stages
        }

        estimated_cost = len(stages) * 0.10
        estimated_time = len(stages) * 15.0

        return ExtractionPlan(
            stages=stages,
            estimated_cost=min(estimated_cost, 2.0),
            estimated_time=min(estimated_time, 120.0),
            model_assignments=model_assignments
        )

    def _create_hierarchical_stages(self, schema: Dict[str, Any]) -> List[ExtractionStage]:
        """Group top-level properties into primitive, enum, array and object stages.

        Stages are created in that order and only when they have fields.
        Each stage depends on every stage created before it, so a dependency
        never names a stage that is missing from the returned list.
        """
        properties = schema.get('properties', {})
        primitive_types = ('string', 'number', 'integer', 'boolean')
        definitions = [
            (field_name, field_def)
            for field_name, field_def in properties.items()
            if isinstance(field_def, dict)
        ]

        # (stage name, member fields, complexity, estimated tokens)
        groups = [
            (
                "primitive_fields",
                [name for name, spec in definitions
                 if spec.get('type') in primitive_types and 'enum' not in spec],
                1,
                1000,
            ),
            (
                "enum_fields",
                [name for name, spec in definitions if 'enum' in spec],
                1,
                1500,
            ),
            (
                "array_fields",
                [name for name, spec in definitions if spec.get('type') == 'array'],
                2,
                2500,
            ),
            (
                "object_fields",
                [name for name, spec in definitions if spec.get('type') == 'object'],
                3,
                4000,
            ),
        ]

        stages: List[ExtractionStage] = []
        for stage_name, stage_fields, complexity, estimated_tokens in groups:
            if not stage_fields:
                continue
            stages.append(ExtractionStage(
                name=stage_name,
                fields=stage_fields,
                schema_subset=self._create_subset_schema(schema, stage_fields),
                complexity=complexity,
                # Only stages that actually exist can be prerequisites.
                dependencies=[earlier.name for earlier in stages],
                estimated_tokens=estimated_tokens
            ))

        return stages

    def _create_subset_schema(self, full_schema: Dict[str, Any], fields: List[str]) -> Dict[str, Any]:
        """Return a copy of ``full_schema`` limited to the given properties.

        Every top-level key is carried over. ``properties`` keeps only the
        names in ``fields``, and a list-valued ``required`` is narrowed to
        those same names so the subset never requires a field it does not
        define. ``full_schema`` is not modified.
        """
        properties = full_schema.get('properties', {})
        subset_properties = {name: properties[name] for name in fields if name in properties}

        subset = {k: v for k, v in full_schema.items() if k != 'properties'}
        subset['properties'] = subset_properties

        required = full_schema.get('required')
        if isinstance(required, list):
            subset['required'] = [name for name in required if name in subset_properties]

        return subset
|
|
class DocumentProcessor:
    """Split documents that exceed a size limit into overlapping chunks."""

    def __init__(self, max_chunk_size: int = 100000):
        """Store the maximum chunk length, measured in characters."""
        self.max_chunk_size = max_chunk_size

    def process_document(self, content: str, schema: Dict[str, Any]) -> List[str]:
        """Return ``content`` as a list of chunks no longer than the limit.

        Content that already fits is returned unchanged as a single chunk.
        """
        if len(content) <= self.max_chunk_size:
            return [content]

        logger.info(f"Document size {len(content)} exceeds chunk limit, creating semantic chunks")
        return self._semantic_chunking(content, schema)

    def _semantic_chunking(self, content: str, schema: Dict[str, Any]) -> List[str]:
        """Pack paragraphs into chunks, carrying a short tail into the next chunk.

        Paragraphs are separated by a blank line. A paragraph too long to
        fit is cut into slices, so no chunk exceeds ``max_chunk_size`` (for
        limits of at least one character). ``schema`` is accepted but not
        used.
        """
        separator = "\n\n"
        limit = max(1, self.max_chunk_size)
        # Up to 1000 characters of overlap, scaled down for small limits so
        # the overlap can never crowd out new content.
        overlap_size = min(1000, limit // 4)
        # Largest piece that still fits after an overlap tail and separator.
        piece_limit = max(1, limit - overlap_size - len(separator))

        # Each piece remembers how it joins the text before it: a paragraph
        # start uses the separator, a continuation slice joins directly.
        pieces = []
        for para in content.split(separator):
            if len(para) <= piece_limit:
                pieces.append((separator, para))
                continue
            for offset in range(0, len(para), piece_limit):
                joiner = separator if offset == 0 else ""
                pieces.append((joiner, para[offset:offset + piece_limit]))

        chunks = []
        current_chunk = ""
        for joiner, text in pieces:
            candidate = current_chunk + joiner + text if current_chunk else text
            if len(candidate) <= limit:
                current_chunk = candidate
                continue

            chunks.append(current_chunk)
            # s[-0:] would return the whole string, so handle zero overlap
            # explicitly.
            tail = current_chunk[-overlap_size:] if overlap_size else ""
            current_chunk = tail + joiner + text if tail else text

        if current_chunk:
            chunks.append(current_chunk)

        logger.info(f"Created {len(chunks)} semantic chunks")
        return chunks
|
|
class ExtractionEngine:
    """Run the stages of an extraction plan against a document."""

    def __init__(self, api_key: str):
        """Create one client per supported model name."""
        self.models = {
            "gpt-4o-mini": OpenAIClient("gpt-4o-mini", api_key),
            "gpt-4o": OpenAIClient("gpt-4o", api_key),
        }

    async def extract(self, content: str, plan: ExtractionPlan, schema: Dict[str, Any]) -> ExtractionResult:
        """Execute every runnable stage of ``plan`` and merge the results.

        Args:
            content: Document text to extract from.
            plan: Stages to run, in order, with their model assignments.
            schema: Full schema; accepted but not used here, since each
                stage carries its own ``schema_subset``.

        Returns:
            An ``ExtractionResult`` with merged data, per-field confidence
            and one summary entry per executed stage.
        """
        loop = asyncio.get_running_loop()
        start_time = loop.time()
        results: Dict[str, Any] = {}
        confidence_scores: Dict[str, float] = {}
        stage_results: List[Dict[str, Any]] = []
        planned_stages = {stage.name for stage in plan.stages}
        completed_stages: List[str] = []

        logger.info(f"Starting extraction with {len(plan.stages)} stages")

        for i, stage in enumerate(plan.stages):
            logger.info(f"Executing stage {i+1}/{len(plan.stages)}: {stage.name}")

            # A dependency on a stage that is not part of the plan can never
            # complete, so it must not block this stage.
            dependencies = [dep for dep in stage.dependencies if dep in planned_stages]
            if not self._dependencies_satisfied(dependencies, results, completed_stages):
                logger.warning(f"Dependencies not satisfied for stage {stage.name}, skipping")
                continue

            stage_start = loop.time()
            context = self._build_context(content, results, stage)
            model_name = plan.model_assignments.get(stage.name, "gpt-4o")
            model = self.models.get(model_name)
            if model is None:
                # The plan named a model this engine has no client for.
                logger.warning(f"Model {model_name} is not configured, falling back to gpt-4o")
                model_name = "gpt-4o"
                model = self.models[model_name]

            prompt = self._create_extraction_prompt(context, stage.schema_subset, results)

            response, confidence = await model.complete(prompt, max_tokens=4000)
            stage_data = self._parse_response(response, stage.fields)

            results.update(stage_data)
            for field_name in stage.fields:
                confidence_scores[field_name] = confidence * (0.9 if field_name in stage_data else 0.3)

            completed_stages.append(stage.name)
            stage_results.append({
                "stage": stage.name,
                "extracted_fields": list(stage_data.keys()),
                "confidence": confidence,
                "model": model_name,
                "processing_time": loop.time() - stage_start
            })

        processing_time = loop.time() - start_time

        return ExtractionResult(
            data=results,
            confidence_scores=confidence_scores,
            stage_results=stage_results,
            metadata={
                "total_stages": len(plan.stages),
                "estimated_cost": plan.estimated_cost,
                "processing_time": processing_time
            },
            processing_time=processing_time
        )

    def _dependencies_satisfied(
        self,
        dependencies: List[str],
        current_results: Dict[str, Any],
        completed_stages: Optional[List[str]] = None
    ) -> bool:
        """Return True when every dependency has been met.

        Dependencies are stage names. When ``completed_stages`` is given
        they are checked against it. Without it the check falls back to
        comparing against the first dotted segment of each result key.
        """
        if completed_stages is not None:
            done = set(completed_stages)
        else:
            done = {key.split('.')[0] for key in current_results}
        return all(dep in done for dep in dependencies)

    def _build_context(self, content: str, previous_results: Dict[str, Any], stage: ExtractionStage) -> str:
        """Build the document section of the prompt.

        At most the first 5000 characters of ``content`` are included, with
        a marker when text was cut, followed by up to 1000 characters of the
        previously extracted data. ``stage`` is accepted but not used.
        """
        context = f"Document Content:\n{content[:5000]}"
        if len(content) > 5000:
            context += "...[truncated]"

        if previous_results:
            context += f"\n\nPreviously Extracted Data:\n{json.dumps(previous_results, indent=2)[:1000]}"

        return context

    def _create_extraction_prompt(self, context: str, schema: Dict[str, Any], previous_results: Dict[str, Any]) -> str:
        """Assemble the full prompt for one stage.

        ``context`` is embedded as given: it is already bounded by
        ``_build_context``, and cutting it again here would drop the
        truncation marker and the earlier results it carries.
        """
        schema_properties = schema.get('properties', {})
        required_fields = schema.get('required', [])

        field_descriptions = []
        for field_name, field_def in schema_properties.items():
            if isinstance(field_def, dict):
                field_type = field_def.get('type', 'string')
                status = "REQUIRED" if field_name in required_fields else "optional"
                field_descriptions.append(f"- {field_name} ({field_type}) [{status}]")

        previous_context = ""
        if previous_results:
            previous_context = f"\n\nPreviously extracted data:\n{json.dumps(previous_results, indent=2)}"

        return f"""Extract ALL specified fields from the document content according to the JSON schema.

DOCUMENT CONTENT:
{context}

REQUIRED OUTPUT FIELDS:
{chr(10).join(field_descriptions)}

SCHEMA STRUCTURE:
{json.dumps(schema, indent=2)}{previous_context}

CRITICAL INSTRUCTIONS:
1. Extract ALL fields specified in the schema properties
2. For arrays, extract ALL items found in the content
3. For objects, extract ALL nested properties
4. Use null only if data truly cannot be found
5. Maintain exact schema structure and types
6. Output ONLY valid JSON, no explanations

JSON OUTPUT:"""

    def _parse_response(self, response: str, expected_fields: List[str]) -> Dict[str, Any]:
        """Extract a JSON object from model output.

        The text between the first ``{`` and the last ``}`` is parsed first.
        If that fails, every ``{`` position is tried as the start of an
        object. When ``expected_fields`` is non-empty only those keys are
        kept, so stray keys (for example an API error payload) never reach
        the extracted data.

        Returns:
            The parsed object, or an empty dict when nothing usable is found.
        """
        if not isinstance(response, str) or not response.strip():
            logger.warning("Response is empty or not text")
            return {}

        wanted = set(expected_fields) if expected_fields else None

        def keep_expected(parsed: Dict[str, Any]) -> Dict[str, Any]:
            if wanted is None:
                return parsed
            return {key: value for key, value in parsed.items() if key in wanted}

        cleaned_response = response.strip()

        if not cleaned_response.startswith('{'):
            json_start = cleaned_response.find('{')
            if json_start != -1:
                cleaned_response = cleaned_response[json_start:]

        if not cleaned_response.endswith('}'):
            json_end = cleaned_response.rfind('}')
            if json_end != -1:
                cleaned_response = cleaned_response[:json_end + 1]

        try:
            data = json.loads(cleaned_response)
        except json.JSONDecodeError as e:
            logger.warning(f"JSON decode error: {e}")
        else:
            if isinstance(data, dict):
                return keep_expected(data)
            logger.warning("Response is not a dictionary")
            return {}

        # Fallback: let the JSON decoder find where each candidate object
        # ends. Unlike a brace-matching regex this handles any nesting depth
        # and braces inside string values.
        decoder = json.JSONDecoder()
        position = response.find('{')
        while position != -1:
            try:
                candidate, _ = decoder.raw_decode(response, position)
            except json.JSONDecodeError:
                candidate = None
            if isinstance(candidate, dict) and candidate:
                selected = keep_expected(candidate)
                if selected:
                    return selected
            position = response.find('{', position + 1)

        logger.error("All JSON parsing attempts failed")
        return {}
|
|
class QualityAssessor:
    """Score an extraction result against its schema and flag review needs."""

    def assess_extraction(self, result: ExtractionResult, schema: Dict[str, Any]) -> QualityReport:
        """Build a ``QualityReport`` for ``result``.

        Overall confidence is 60% completeness, 30% schema compliance and
        10% consistency, capped at 1.0. Completeness counts only properties
        declared in the schema, so extra keys in the data cannot inflate it.
        """
        schema_compliance = self._validate_against_schema(result.data, schema)
        field_scores = result.confidence_scores.copy()
        consistency_score = self._check_consistency(result.data)

        required_fields = schema.get('required', [])
        properties = schema.get('properties', {})
        total_expected_fields = len(properties)
        extracted_fields = sum(1 for name in properties if result.data.get(name) is not None)

        completeness_score = extracted_fields / total_expected_fields if total_expected_fields > 0 else 0

        overall_confidence = completeness_score * 0.6 + schema_compliance * 0.3 + consistency_score * 0.1
        overall_confidence = min(overall_confidence, 1.0)

        review_flags = self._generate_review_flags(field_scores, schema_compliance, overall_confidence, required_fields, result.data, total_expected_fields, extracted_fields)
        review_time = self._estimate_review_time(review_flags, field_scores)

        return QualityReport(
            overall_confidence=overall_confidence,
            field_scores=field_scores,
            review_flags=review_flags,
            schema_compliance=schema_compliance,
            consistency_score=consistency_score,
            recommended_review_time=review_time
        )

    def _validate_against_schema(self, data: Dict[str, Any], schema: Dict[str, Any]) -> float:
        """Return a 0-1 score: 70% required-field presence, 30% type agreement.

        Only top-level fields are checked. A property whose definition is
        not a dict carries no type information and is never a type error.
        """
        required_fields = schema.get('required', [])
        properties = schema.get('properties', {})

        required_present = sum(1 for name in required_fields if data.get(name) is not None)
        required_compliance = required_present / len(required_fields) if required_fields else 1.0

        type_errors = 0
        total_fields = 0
        for name, value in data.items():
            if name not in properties:
                continue
            total_fields += 1
            prop_def = properties[name]
            expected_type = prop_def.get('type') if isinstance(prop_def, dict) else None
            if expected_type and not self._check_type(value, expected_type):
                type_errors += 1

        type_compliance = 1.0 - (type_errors / total_fields) if total_fields > 0 else 1.0

        return (required_compliance * 0.7 + type_compliance * 0.3)

    def _check_type(self, value: Any, expected_type: Any) -> bool:
        """Return True when ``value`` matches the schema ``type`` keyword.

        ``None`` always passes. ``expected_type`` may be one type name or a
        list of names, in which case any match passes. Booleans are not
        accepted as ``integer`` or ``number`` even though ``bool`` is an
        ``int`` subclass in Python.
        """
        if value is None:
            return True

        if isinstance(expected_type, (list, tuple)):
            return any(self._check_type(value, option) for option in expected_type)

        if not isinstance(expected_type, str):
            # Nothing usable to check against.
            return True

        if expected_type in ('number', 'integer') and isinstance(value, bool):
            return False

        type_mapping = {
            'string': str,
            'number': (int, float),
            'integer': int,
            'boolean': bool,
            'array': list,
            'object': dict,
            'null': type(None)
        }
        # Unrecognised type names are checked as strings.
        expected_python_type = type_mapping.get(expected_type, str)
        return isinstance(value, expected_python_type)

    def _check_consistency(self, data: Dict[str, Any]) -> float:
        """Return a 0.7-1.0 score from simple cross-field sanity checks.

        Checks an ``email`` value for ``@`` and that ``startDate`` does not
        sort after ``endDate``, then multiplies in the scores of nested
        dicts and of dicts inside lists.
        """
        consistency_score = 1.0

        if 'email' in data and data['email']:
            if '@' not in str(data['email']):
                consistency_score -= 0.1

        # Compared as strings.
        # NOTE(review): this only orders dates correctly for formats that
        # sort lexicographically, such as ISO 8601 - confirm the input format.
        start, end = data.get('startDate'), data.get('endDate')
        if start and end and str(start) > str(end):
            consistency_score -= 0.15

        if isinstance(data, dict):
            for value in data.values():
                if isinstance(value, list):
                    for item in value:
                        if isinstance(item, dict):
                            consistency_score *= self._check_consistency(item)
                elif isinstance(value, dict):
                    consistency_score *= self._check_consistency(value)

        return max(0.7, consistency_score)

    def _generate_review_flags(self, field_scores: Dict[str, float], schema_compliance: float, overall_confidence: float, required_fields: List[str], extracted_data: Dict[str, Any], total_expected: int, extracted_count: int) -> List[str]:
        """Return the review flags triggered by the given scores and data.

        ``field_scores`` is accepted but not used.
        """
        flags = []

        completeness_rate = extracted_count / total_expected if total_expected > 0 else 0

        if completeness_rate < 0.5:
            flags.append("incomplete_extraction")
        elif completeness_rate < 0.8:
            flags.append("partial_extraction")

        if overall_confidence < 0.6:
            flags.append("low_quality")
        elif overall_confidence < 0.8:
            flags.append("moderate_quality")

        if schema_compliance < 0.7:
            flags.append("schema_violations")

        if any(extracted_data.get(name) is None for name in required_fields):
            flags.append("missing_required_fields")

        empty_fields = [k for k, v in extracted_data.items() if v is None or v == ""]
        if len(empty_fields) > total_expected * 0.3:
            flags.append("many_empty_fields")

        return flags

    def _estimate_review_time(self, review_flags: List[str], field_scores: Dict[str, float]) -> int:
        """Return 0 when nothing is flagged, else 5 plus 2 per low-confidence field, capped at 60."""
        if not review_flags:
            return 0

        low_confidence_count = sum(1 for score in field_scores.values() if score < 0.7)
        base_time = 5
        field_time = low_confidence_count * 2

        return min(base_time + field_time, 60)
|
|
class StructuredExtractionSystem:
    """Top-level pipeline: analyse the schema, plan, extract and assess."""

    def __init__(self, api_key: str):
        """Create the analyser, document processor, engine and assessor."""
        self.schema_analyzer = SchemaAnalyzer()
        self.document_processor = DocumentProcessor()
        self.extraction_engine = ExtractionEngine(api_key)
        self.quality_assessor = QualityAssessor()

    async def extract_structured_data(
        self,
        content: str,
        schema: Dict[str, Any],
        options: Optional[Dict[str, Any]] = None
    ) -> Dict[str, Any]:
        """Extract data matching ``schema`` from ``content``.

        Args:
            content: Document text.
            schema: JSON-schema-like dict describing the fields to extract.
            options: Accepted but not used.

        Returns:
            A dict with ``data``, ``confidence_scores``,
            ``overall_confidence``, ``review_flags`` and
            ``extraction_metadata``. Only the first chunk of an oversized
            document is extracted; ``chunks_total`` and ``chunks_processed``
            in the metadata show when content was left out.
        """
        start_time = datetime.now()

        logger.info("Starting structured data extraction")
        logger.info(f"Content length: {len(content)} characters")

        complexity = self.schema_analyzer.analyze_complexity(schema)
        logger.info(f"Schema complexity: Tier {complexity.complexity_tier}")

        plan = self.schema_analyzer.create_extraction_plan(schema, complexity)
        logger.info(f"Extraction plan: {len(plan.stages)} stages")

        chunks = self.document_processor.process_document(content, schema)
        logger.info(f"Document chunks: {len(chunks)}")

        chunks_processed = 1
        if len(chunks) > chunks_processed:
            # Make the data loss visible instead of dropping chunks silently.
            logger.warning(
                f"Only the first of {len(chunks)} chunks is extracted; the remaining content is ignored"
            )

        result = await self.extraction_engine.extract(chunks[0], plan, schema)
        quality = self.quality_assessor.assess_extraction(result, schema)

        processing_time = (datetime.now() - start_time).total_seconds()

        logger.info(f"Extraction completed in {processing_time:.2f} seconds")
        logger.info(f"Overall confidence: {quality.overall_confidence:.3f}")

        return {
            "data": result.data,
            "confidence_scores": result.confidence_scores,
            "overall_confidence": quality.overall_confidence,
            "review_flags": quality.review_flags,
            "extraction_metadata": {
                "complexity_tier": complexity.complexity_tier,
                # Count stages that actually ran; skipped stages are not
                # recorded in stage_results.
                "stages_executed": len(result.stage_results),
                "estimated_cost": plan.estimated_cost,
                "actual_processing_time": processing_time,
                "schema_compliance": quality.schema_compliance,
                "recommended_review_time": quality.recommended_review_time,
                "chunks_total": len(chunks),
                "chunks_processed": chunks_processed
            }
        }