File size: 7,093 Bytes
db81e28
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
import json
import asyncio
import base64
import uuid
import time
import random
from openai import AsyncOpenAI
import config
import utils_geometry as utils
from schema_definitions import VALID_FIELD_TYPES

client = AsyncOpenAI(api_key=config.OPENAI_API_KEY)
#updated the code to push

# ==========================================
# WORKER: PROCESS SINGLE BATCH (With Retries)
# ==========================================
async def process_single_batch(semaphore, doc, page_num, batch_idx, batch_fields, global_context):
    """Classify one batch of form-field boxes on a PDF page via a vision LLM.

    Renders a debug image highlighting the batch's boxes, sends it with
    per-field spatial hints to the model, and maps the JSON response back
    onto the input fields. Retries with exponential backoff + jitter on
    rate-limit errors; any other error aborts the batch.

    Args:
        semaphore: asyncio.Semaphore bounding concurrent API calls.
        doc: open PDF document (PyMuPDF-style, indexable by page number).
        page_num: zero-based page index within ``doc``.
        batch_idx: index of this batch on the page (used in debug tags).
        batch_fields: field dicts; each must already carry "temp_id",
            "bbox" and "debug_anchors" (set by ``process_page_smart``).
        global_context: document-level context string injected into the prompt.

    Returns:
        List of normalized result dicts, one per field; empty list if the
        image could not be rendered or all retries were exhausted.
    """
    async with semaphore:
        # A. Per-field prompt hints from the precomputed spatial anchors.
        prompt_items = []
        for f in batch_fields:
            anchors = f["debug_anchors"]
            prompt_items.append(
                f"- Box ID {f['temp_id']}:\n"
                f"  Spatial Hints -> Left: '{anchors['left']}' | Above: '{anchors['above']}'\n"
                f"  PDF Type Hint: {f.get('ft', 'text')}"
            )

        # B. Render image (render is sync/CPU-bound, so run it off-loop).
        img_bytes = await asyncio.to_thread(utils.render_hollow_debug_image, doc, page_num, batch_fields)
        if not img_bytes: return []

        # Save debug artifact
        tag = f"{page_num}_batch_{batch_idx}"
        utils.save_debug_image(img_bytes, tag)
        b64_img = base64.b64encode(img_bytes).decode('utf-8')

        # C. System Prompt (UPDATED with section_context)
        valid_types_str = ", ".join(VALID_FIELD_TYPES)

        system_prompt = f"""
        You are an expert Legal Document Processor.
        CONTEXT: Real Estate Contract. Global Context: "{global_context}"
        
        TASK: Analyze the Neon Green Boxes (IDs {batch_fields[0]['temp_id']} to {batch_fields[-1]['temp_id']}).
        
        OUTPUT RULES:
        For each box, return JSON with:
        1. "visual_evidence": Text closest to the box.
        2. "section_context": The BOLD HEADER or SECTION TITLE this field belongs to (e.g. "Purchase Price", "Property Description", "Closing Date").
        3. "final_label": Precise natural label (e.g. "Purchase Price", "Seller Signature").
        4. "role": Who fills this out? Choose ONLY from: 
           [Buyer, Seller, Agent, Broker, President, Reviewer, Disclosing Party, Receiving Party, N/A].
           - If ambiguous, infer from section header (e.g. "Tenant's Signature" -> "Tenant").
           - If strictly administrative (e.g. "Office Use Only"), return "System".
        5. "detected_type": MUST be one of [{valid_types_str}]. 
           - If it looks like money ($), use "dollar".
           - If it looks like a date, use "date".
           - If it's a signature, use "signature".
        
        INPUT DATA:
        {chr(10).join(prompt_items)}
        
        Return JSON: {{ "fields": [ {{ "box_id": 1, ... }} ] }}
        """

        # D. Retry loop with exponential backoff on rate limits.
        MAX_RETRIES = 5
        BASE_DELAY = 2
        page_height = doc[page_num].rect.height

        for attempt in range(MAX_RETRIES):
            try:
                response = await client.chat.completions.create(
                    model="gpt-4o", # Use gpt-4o for best vision
                    response_format={"type": "json_object"},
                    messages=[
                        {"role": "user", "content": [
                            {"type": "text", "text": system_prompt},
                            {"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{b64_img}"}}
                        ]}
                    ],
                    temperature=0.0
                )

                content = response.choices[0].message.content
                parsed = json.loads(content)
                utils.save_debug_json(parsed, f"{tag}_vision_response")

                results_map = {item["box_id"]: item for item in parsed.get("fields", [])}

                # Fresh list per attempt: if an exception fires partway
                # through this loop and we retry, we must not keep (and
                # later duplicate) the entries already appended.
                attempt_results = []
                for f in batch_fields:
                    res = results_map.get(f["temp_id"], {})
                    label = res.get("final_label", "")

                    # Fallback: derive a label from geometry anchors when
                    # the model gave nothing usable.
                    if not label or label == "Unknown":
                        anchors = f["debug_anchors"]
                        label = anchors["left"] if anchors["left"] else (anchors["above"] if anchors["above"] else "Unknown Field")

                    norm_rect = utils.normalize_bbox_to_top_left(f["bbox"], page_height)

                    attempt_results.append({
                        "id": f.get("name", str(uuid.uuid4())[:8]),
                        "temp_id": f["temp_id"],
                        "label": label,
                        "section": res.get("section_context", "General Information"), # <--- Capturing Section Context
                        "role": res.get("role", "System"),
                        "detected_type": res.get("detected_type", "shortText"), 
                        "uiType": "checkbox" if f.get("ft") == "/Btn" else "text",
                        "page": page_num,
                        "rect": {
                            "x": norm_rect["x0"], "y": norm_rect["y0"],
                            "width": norm_rect["x1"] - norm_rect["x0"], "height": norm_rect["y1"] - norm_rect["y0"]
                        },
                        "debug_evidence": res.get("visual_evidence", "N/A")
                    })
                return attempt_results  # success

            except Exception as e:
                error_msg = str(e)
                if "429" in error_msg or "Rate limit" in error_msg:
                    if attempt == MAX_RETRIES - 1:
                        # No attempts left — sleeping here would be wasted time.
                        print(f"Rate Limit ({tag}). Retries exhausted.")
                        break
                    wait_time = (BASE_DELAY * (2 ** attempt)) + (random.random() * 0.5)
                    print(f"Rate Limit ({tag}). Waiting {wait_time:.2f}s...")
                    await asyncio.sleep(wait_time) # Use await sleep for async!
                else:
                    print(f"Error {tag}: {e}")
                    break

        return []

#
# ==========================================
# ORCHESTRATOR: PROCESS PAGE
# ==========================================
async def process_page_smart(semaphore, doc, page_num, fields, global_context):
    """Annotate fields with IDs/anchors, batch them, and fan out vision calls.

    Mutates each dict in ``fields`` in place (adds "temp_id" and
    "debug_anchors"), splits them into batches of ``config.VISION_BATCH_SIZE``,
    runs ``process_single_batch`` for every batch concurrently, and returns
    the flattened list of per-field results.
    """
    page = doc[page_num]
    words = utils.get_words_from_page(page)
    height = page.rect.height

    # 1. Assign 1-based IDs and the spatial anchors the prompt builder needs.
    for position, field in enumerate(fields, start=1):
        field["temp_id"] = position
        field["debug_anchors"] = utils.calculate_smart_anchors(field["bbox"], words, height)

    # 2. Chunk the fields into fixed-size batches.
    step = config.VISION_BATCH_SIZE
    batches = [fields[start:start + step] for start in range(0, len(fields), step)]

    print(f"📄 Page {page_num}: Queuing {len(batches)} batches for {len(fields)} fields...")

    # 3+4. gather() schedules each coroutine as a task, so the batches run
    # concurrently; await collects one result list per batch.
    pending = [
        process_single_batch(semaphore, doc, page_num, idx, chunk, global_context)
        for idx, chunk in enumerate(batches)
    ]
    per_batch = await asyncio.gather(*pending)
    return [entry for batch in per_batch for entry in batch]