field_semantic_mapping / engine_vision.py
Tanishq Salkar
initial visual mapping code added to hf
db81e28
import json
import asyncio
import base64
import uuid
import time
import random
from openai import AsyncOpenAI
import config
import utils_geometry as utils
from schema_definitions import VALID_FIELD_TYPES
# Module-level async OpenAI client, constructed once at import time from
# config.OPENAI_API_KEY and shared by every vision request in this module.
client = AsyncOpenAI(api_key=config.OPENAI_API_KEY)
# ==========================================
# WORKER: PROCESS SINGLE BATCH (With Retries)
# ==========================================
def _format_prompt_item(field):
    """Format one box's spatial and PDF-type hints as a bullet for the prompt."""
    anchors = field["debug_anchors"]
    return (
        f"- Box ID {field['temp_id']}:\n"
        f" Spatial Hints -> Left: '{anchors['left']}' | Above: '{anchors['above']}'\n"
        f" PDF Type Hint: {field.get('ft', 'text')}"
    )


def _build_vision_prompt(global_context, batch_fields):
    """Build the instruction text sent alongside the rendered page image.

    ``batch_fields`` must be non-empty: the first and last ``temp_id`` are
    quoted in the TASK line.
    """
    valid_types_str = ", ".join(VALID_FIELD_TYPES)
    input_data = "\n".join(_format_prompt_item(f) for f in batch_fields)
    first_id = batch_fields[0]["temp_id"]
    last_id = batch_fields[-1]["temp_id"]
    # NOTE(review): the role rules below contradict the "Choose ONLY from" list
    # ("Tenant" and "System" are not in it). The result builder also defaults a
    # missing role to "System". Confirm which set of roles downstream accepts.
    return f"""
You are an expert Legal Document Processor.
CONTEXT: Real Estate Contract. Global Context: "{global_context}"
TASK: Analyze the Neon Green Boxes (IDs {first_id} to {last_id}).
OUTPUT RULES:
For each box, return JSON with:
1. "visual_evidence": Text closest to the box.
2. "section_context": The BOLD HEADER or SECTION TITLE this field belongs to (e.g. "Purchase Price", "Property Description", "Closing Date").
3. "final_label": Precise natural label (e.g. "Purchase Price", "Seller Signature").
4. "role": Who fills this out? Choose ONLY from:
[Buyer, Seller, Agent, Broker, President, Reviewer, Disclosing Party, Receiving Party, N/A].
- If ambiguous, infer from section header (e.g. "Tenant's Signature" -> "Tenant").
- If strictly administrative (e.g. "Office Use Only"), return "System".
5. "detected_type": MUST be one of [{valid_types_str}].
- If it looks like money ($), use "dollar".
- If it looks like a date, use "date".
- If it's a signature, use "signature".
INPUT DATA:
{input_data}
Return JSON: {{ "fields": [ {{ "box_id": 1, ... }} ] }}
"""


def _index_vision_fields(parsed):
    """Map integer box_id -> model result dict, skipping malformed entries.

    The model may return ``box_id`` as a string ("3") rather than an int. The
    ids are coerced so they match the integer ``temp_id`` values. Entries that
    are not dicts, or whose id is missing or non-numeric, are ignored rather
    than failing the whole batch.
    """
    items = parsed.get("fields") if isinstance(parsed, dict) else None
    indexed = {}
    for item in items or []:
        if not isinstance(item, dict):
            continue
        try:
            box_id = int(item["box_id"])
        except (KeyError, TypeError, ValueError):
            continue
        indexed[box_id] = item
    return indexed


def _is_rate_limit_error(exc):
    """Return True if *exc* looks like an HTTP 429 / rate-limit API error."""
    # openai's APIStatusError subclasses expose ``status_code``. The string
    # checks keep the original heuristic for any other error shape.
    if getattr(exc, "status_code", None) == 429:
        return True
    message = str(exc)
    return "429" in message or "Rate limit" in message


def _build_field_result(field, res, page_num, page_height):
    """Combine one field dict with its (possibly empty) model result into an output record."""
    label = res.get("final_label", "")
    # Fallback Geometry Logic: with no usable model label, use the nearest text anchor.
    if not label or label == "Unknown":
        anchors = field["debug_anchors"]
        label = anchors["left"] or anchors["above"] or "Unknown Field"

    # Converts the field bbox using the page height (helper is defined in
    # utils_geometry; presumably flips to a top-left origin — confirm there).
    norm_rect = utils.normalize_bbox_to_top_left(field["bbox"], page_height)
    return {
        # Missing, None or empty names all fall back to a short random id.
        "id": field.get("name") or str(uuid.uuid4())[:8],
        "temp_id": field["temp_id"],
        "label": label,
        "section": res.get("section_context", "General Information"),
        "role": res.get("role", "System"),
        "detected_type": res.get("detected_type", "shortText"),
        "uiType": "checkbox" if field.get("ft") == "/Btn" else "text",
        "page": page_num,
        "rect": {
            "x": norm_rect["x0"],
            "y": norm_rect["y0"],
            "width": norm_rect["x1"] - norm_rect["x0"],
            "height": norm_rect["y1"] - norm_rect["y0"],
        },
        "debug_evidence": res.get("visual_evidence", "N/A"),
    }


async def process_single_batch(semaphore, doc, page_num, batch_idx, batch_fields, global_context):
    """Label one batch of form fields on a page with a single vision-model request.

    The function renders the page with the batch's boxes highlighted and sends
    that image, plus per-box text hints, to the model. It then turns the JSON
    reply into one result dict per field. ``semaphore`` is held for the render
    and for every API attempt, including rate-limit back-off sleeps. This means
    throttled batches keep their concurrency slot.

    Args:
        semaphore: asyncio semaphore shared by all concurrently running batches.
        doc: the open document. ``doc[page_num].rect.height`` is read here, and
            ``doc`` is passed to the render helper.
        page_num: index of the page within ``doc``.
        batch_idx: index of this batch on the page; used only in the debug tag.
        batch_fields: field dicts. Each must already carry ``temp_id``,
            ``debug_anchors`` and ``bbox``. ``name`` and ``ft`` are optional.
        global_context: free text inserted verbatim into the prompt.

    Returns:
        One dict per field, in ``batch_fields`` order, on success. The function
        returns ``[]`` in any of these cases:

        - the batch is empty;
        - the render produced no image;
        - a non-rate-limit error occurred;
        - rate-limit retries were exhausted.
    """
    if not batch_fields:
        # Nothing to label (and the prompt would index batch_fields[0]).
        return []

    async with semaphore:
        # A. Prompt text (per-box hints + instructions)
        system_prompt = _build_vision_prompt(global_context, batch_fields)

        # B. Render Image (blocking render runs in a worker thread)
        img_bytes = await asyncio.to_thread(utils.render_hollow_debug_image, doc, page_num, batch_fields)
        if not img_bytes:
            return []

        # Save debug artifact
        tag = f"{page_num}_batch_{batch_idx}"
        utils.save_debug_image(img_bytes, tag)
        b64_img = base64.b64encode(img_bytes).decode("utf-8")

        # C. Retry Logic: exponential back-off with jitter on rate limits only.
        max_retries = 5
        base_delay = 2
        page_height = doc[page_num].rect.height
        for attempt in range(max_retries):
            try:
                response = await client.chat.completions.create(
                    model="gpt-4o",  # Use gpt-4o for best vision
                    response_format={"type": "json_object"},
                    messages=[
                        {"role": "user", "content": [
                            {"type": "text", "text": system_prompt},
                            {"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{b64_img}"}},
                        ]}
                    ],
                    temperature=0.0,
                )
                parsed = json.loads(response.choices[0].message.content)
                utils.save_debug_json(parsed, f"{tag}_vision_response")
                results_map = _index_vision_fields(parsed)
                # The records are built into a fresh list and returned in one
                # step. An error partway through therefore never leaves partial rows.
                return [
                    _build_field_result(f, results_map.get(f["temp_id"], {}), page_num, page_height)
                    for f in batch_fields
                ]
            except Exception as e:  # API boundary: log and degrade to an empty batch
                if not _is_rate_limit_error(e):
                    print(f"Error {tag}: {e}")
                    return []
                if attempt == max_retries - 1:
                    # No sleep after the final attempt: nothing would follow it.
                    print(f"Rate Limit ({tag}). Giving up after {max_retries} attempts.")
                    return []
                wait_time = (base_delay * (2 ** attempt)) + (random.random() * 0.5)
                print(f"Rate Limit ({tag}). Waiting {wait_time:.2f}s...")
                await asyncio.sleep(wait_time)
        return []
# ==========================================
# ORCHESTRATOR: PROCESS PAGE
# ==========================================
async def process_page_smart(semaphore, doc, page_num, fields, global_context):
    """Label every form field on one page by fanning batches out to the vision worker.

    Each dict in ``fields`` is mutated in place. It receives a 1-based
    ``temp_id`` (its position in ``fields``) and ``debug_anchors`` (from
    ``utils.calculate_smart_anchors``). The fields are then cut into chunks of
    ``config.VISION_BATCH_SIZE``. Each chunk runs as its own task through
    ``process_single_batch``, which acquires the shared ``semaphore``.

    Returns:
        A flat list of per-field result dicts, in batch order.
    """
    page = doc[page_num]
    words_on_page = utils.get_words_from_page(page)
    height = page.rect.height

    # 1. Pre-calc anchors (box IDs shown to the model start at 1)
    for temp_id, field in enumerate(fields, start=1):
        field["temp_id"] = temp_id
        field["debug_anchors"] = utils.calculate_smart_anchors(field["bbox"], words_on_page, height)

    # 2. Create Batches
    size = config.VISION_BATCH_SIZE
    batches = []
    for start in range(0, len(fields), size):
        batches.append(fields[start:start + size])
    print(f"📄 Page {page_num}: Queuing {len(batches)} batches for {len(fields)} fields...")

    # 3. Spawn parallel tasks (scheduled immediately on creation)
    tasks = [
        asyncio.create_task(process_single_batch(semaphore, doc, page_num, idx, chunk, global_context))
        for idx, chunk in enumerate(batches)
    ]

    # 4. Gather (order-preserving) and flatten
    merged = []
    for chunk_results in await asyncio.gather(*tasks):
        merged.extend(chunk_results)
    return merged