github-actions[bot] committed on
Commit
a54dc28
·
1 Parent(s): a736bf4

Sync from GitHub: 3bae2496b8a3786d399c36363516d096a3b7421b

Browse files
Files changed (1) hide show
  1. inference.py +61 -123
inference.py CHANGED
@@ -63,57 +63,50 @@ Output rules:
63
  """
64
 
65
 
66
- # Two-step Chain of Thought prompts (reasoning mode) - OPTIMIZED FOR SPEED
67
- REASONING_PROMPT = """
68
- Analyze this Indian tractor invoice and share your observations about extracting these 2 fields:
69
- 'model name' and its corresponding 'Horse Power'
70
 
71
- Think through each field:
 
 
 
 
72
 
73
  MODEL NAME:
74
- - How is the model presented? (checkbox list, handwritten field, printed text)
75
- - If checkboxes exist, which one is marked/selected?
76
- - What exact text do you see for the model?
77
- - Is it in English or regional language?
78
 
79
  HORSE POWER:
80
- - Where do you see HP mentioned?
81
- - Is it explicit (like "49 HP") or in a checkbox list?
82
- - If checkboxes, which HP value is selected?
83
- - What exact text shows the HP?
84
- - Horse power must come ONLY from explicit HP text, never from model numbers.
85
- - Horse power may appear as "HP", handwritten like "49 HP", "63hp", "HP-30".
86
-
87
- Express your observations naturally. Be specific about what you see and any uncertainties.
88
- """
89
-
90
 
91
- EXTRACTION_WITH_CONTEXT_PROMPT = """
92
- Based on the image and following analysis regarding 'model':
93
- {reasoning_output}
 
94
 
95
- Extract these fields from image and analysis:
96
 
 
97
  {{
 
98
  "dealer_name": string,
99
  "model_name": string,
100
  "horse_power": number,
101
  "asset_cost": number
102
  }}
103
 
104
- Critical rules:
105
-
106
- - Dealer name must be copied exactly from the image in the original language and spelling.
107
- - Model name must be copied exactly from the image without translation.
108
- - HP: Number only (e.g., "49 HP" → 49). Use selected checkbox if applicable
109
- - ASSET COST: Final total as number (remove ₹, commas: "1,50,000" → 150000)
110
- - Checkboxes: Extract only marked options
111
-
112
- Extraction hints:
113
- - Asset cost is the total amount, usually the largest number on the page, the total amount after TAX, final price or final cost.
114
- - Dealer name is usually at the top header or company name.
115
- - Model name often appears near words like Model, Tractor, Variant.
116
- - If handwriting is unclear, make your best reasonable interpretation of the characters — but preserve language.
117
 
118
  Output ONLY valid JSON, no markdown.
119
  """
@@ -242,10 +235,10 @@ class InferenceProcessor:
242
  return output_text, latency
243
 
244
  @staticmethod
245
- def run_vlm_reasoning(image: Image.Image) -> Tuple[str, float]:
246
  """
247
- Run VLM model for Chain of Thought reasoning phase (step 1 of 2)
248
- Analyzes document structure and observes field locations
249
  """
250
  if not model_manager.is_loaded():
251
  raise RuntimeError("Models not loaded")
@@ -258,7 +251,7 @@ class InferenceProcessor:
258
  "role": "user",
259
  "content": [
260
  {"type": "image", "image": image},
261
- {"type": "text", "text": REASONING_PROMPT}
262
  ]
263
  }
264
  ]
@@ -283,8 +276,8 @@ class InferenceProcessor:
283
 
284
  start = time.time()
285
 
286
- # Generate (reduced tokens for faster processing)
287
- generated_ids = model.generate(**inputs, max_new_tokens=256)
288
 
289
  latency = time.time() - start
290
 
@@ -305,78 +298,29 @@ class InferenceProcessor:
305
  if torch.cuda.is_available():
306
  torch.cuda.empty_cache()
307
 
308
- print(f"🧠 Reasoning phase completed in {latency:.2f}s")
309
- return output_text, latency
310
-
311
- @staticmethod
312
- def run_vlm_extraction_with_context(image: Image.Image, reasoning_output: str) -> Tuple[str, float]:
313
- """
314
- Run VLM model for extraction phase (step 2 of 2) using reasoning context
315
- Extracts structured fields based on previous reasoning
316
- """
317
- if not model_manager.is_loaded():
318
- raise RuntimeError("Models not loaded")
319
-
320
- model = model_manager.vlm_model
321
- processor = model_manager.processor
322
-
323
- # Format the extraction prompt with reasoning context
324
- extraction_prompt = EXTRACTION_WITH_CONTEXT_PROMPT.format(reasoning_output=reasoning_output)
325
-
326
- messages = [
327
- {
328
- "role": "user",
329
- "content": [
330
- {"type": "image", "image": image},
331
- {"type": "text", "text": extraction_prompt}
332
- ]
333
- }
334
- ]
335
-
336
- # Apply chat template
337
- text = processor.apply_chat_template(
338
- messages,
339
- tokenize=False,
340
- add_generation_prompt=True
341
- )
342
-
343
- # Process vision input
344
- image_inputs, video_inputs = process_vision_info(messages)
345
- inputs = processor(
346
- text=[text],
347
- images=image_inputs,
348
- videos=video_inputs,
349
- padding=True,
350
- return_tensors="pt",
351
- )
352
- inputs = inputs.to("cuda")
353
-
354
- start = time.time()
355
-
356
- # Generate
357
- generated_ids = model.generate(**inputs, max_new_tokens=256)
358
-
359
- latency = time.time() - start
360
-
361
- # Decode output
362
- generated_ids_trimmed = [
363
- out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
364
- ]
365
- output_text = processor.batch_decode(
366
- generated_ids_trimmed,
367
- skip_special_tokens=True,
368
- clean_up_tokenization_spaces=False
369
- )
370
-
371
- output_text = output_text[0] if isinstance(output_text, list) else output_text
372
 
373
- # Clean up GPU memory
374
- del inputs, generated_ids, generated_ids_trimmed
375
- if torch.cuda.is_available():
376
- torch.cuda.empty_cache()
377
-
378
- print(f"📝 Extraction phase completed in {latency:.2f}s")
379
- return output_text, latency
 
 
 
 
 
 
 
 
 
 
 
 
380
 
381
  @staticmethod
382
  def extract_json_from_output(text: str) -> Dict:
@@ -562,17 +506,11 @@ class InferenceProcessor:
562
  # Step 3: VLM Extraction (either simple or with Chain of Thought reasoning)
563
  t3 = time.time()
564
  if reasoning_mode == "reason":
565
- # Two-step Chain of Thought approach
566
- print("🧠 Using Chain of Thought reasoning mode (2-step)")
567
 
568
- # Step 3a: Reasoning phase
569
- reasoning_output, reasoning_latency = InferenceProcessor.run_vlm_reasoning(image)
570
- timing_breakdown['vlm_reasoning'] = round(reasoning_latency, 3)
571
-
572
- # Step 3b: Extraction phase with context
573
- vlm_output, extraction_latency = InferenceProcessor.run_vlm_extraction_with_context(image, reasoning_output)
574
- timing_breakdown['vlm_extraction'] = round(extraction_latency, 3)
575
- timing_breakdown['vlm_inference_total'] = round(reasoning_latency + extraction_latency, 3)
576
 
577
  # Store reasoning for debugging/transparency
578
  timing_breakdown['reasoning_output'] = reasoning_output
 
63
  """
64
 
65
 
66
+ # Combined Chain of Thought prompt (reasoning mode) - Single call with reasoning and extraction
67
+ COMBINED_REASONING_EXTRACTION_PROMPT = """
68
+ Analyze this Indian tractor invoice using Chain of Thought reasoning.
 
69
 
70
+ First, share your observations about the 4 key fields:
71
+
72
+ DEALER NAME:
73
+ - Where do you see it? (header, letterhead, stamp)
74
+ - What language? What exact text?
75
 
76
  MODEL NAME:
77
+ - How is it presented? (checkbox/handwritten/printed)
78
+ - If checkboxes, which is marked?
79
+ - What exact text do you see?
 
80
 
81
  HORSE POWER:
82
+ - Where is HP mentioned?
83
+ - Explicit text like "49 HP" or in checkbox?
84
+ - Which value is selected?
85
+ - HP must come from explicit HP text only, never from model numbers
 
 
 
 
 
 
86
 
87
+ ASSET COST:
88
+ - Where is the final total?
89
+ - Which amount is after all taxes?
90
+ - What exact amount with currency?
91
 
92
+ After reasoning, extract the fields.
93
 
94
+ Return ONLY valid JSON:
95
  {{
96
+ "reasoning": "your observations and thoughts here",
97
  "dealer_name": string,
98
  "model_name": string,
99
  "horse_power": number,
100
  "asset_cost": number
101
  }}
102
 
103
+ Rules for extraction:
104
+ - Copy dealer/model names EXACTLY in original language, don't translate
105
+ - HP as number only ("49 HP" → 49), use selected checkbox
106
+ - Asset cost as number (remove ₹, commas: "1,50,000" → 150000)
107
+ - Asset cost is the final total after TAX
108
+ - Dealer is usually at top header
109
+ - If handwriting unclear, make best interpretation but preserve language
 
 
 
 
 
 
110
 
111
  Output ONLY valid JSON, no markdown.
112
  """
 
235
  return output_text, latency
236
 
237
  @staticmethod
238
+ def run_vlm_reasoning_and_extraction(image: Image.Image) -> Tuple[str, str, float]:
239
  """
240
+ Run VLM model with combined Chain of Thought reasoning and extraction in single call
241
+ Returns: (reasoning_text, extraction_json_str, latency)
242
  """
243
  if not model_manager.is_loaded():
244
  raise RuntimeError("Models not loaded")
 
251
  "role": "user",
252
  "content": [
253
  {"type": "image", "image": image},
254
+ {"type": "text", "text": COMBINED_REASONING_EXTRACTION_PROMPT}
255
  ]
256
  }
257
  ]
 
276
 
277
  start = time.time()
278
 
279
+ # Generate with more tokens for combined reasoning + extraction
280
+ generated_ids = model.generate(**inputs, max_new_tokens=384)
281
 
282
  latency = time.time() - start
283
 
 
298
  if torch.cuda.is_available():
299
  torch.cuda.empty_cache()
300
 
301
+ # Parse the combined output to separate reasoning from extraction
302
+ reasoning_text = ""
303
+ extraction_json = output_text
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
304
 
305
+ try:
306
+ # Try to parse as JSON
307
+ parsed = json.loads(output_text.strip())
308
+ if "reasoning" in parsed:
309
+ reasoning_text = parsed["reasoning"]
310
+ # Remove reasoning from output to get clean extraction JSON
311
+ extraction_dict = {k: v for k, v in parsed.items() if k != "reasoning"}
312
+ extraction_json = json.dumps(extraction_dict)
313
+ except:
314
+ # If parsing fails, try to split manually
315
+ # Look for JSON pattern
316
+ json_match = re.search(r'\{[^{}]*(?:\{[^{}]*\}[^{}]*)*\}', output_text, re.DOTALL)
317
+ if json_match:
318
+ extraction_json = json_match.group(0)
319
+ # Everything before JSON is reasoning
320
+ reasoning_text = output_text[:json_match.start()].strip()
321
+
322
+ print(f"🧠 Combined reasoning + extraction completed in {latency:.2f}s")
323
+ return reasoning_text, extraction_json, latency
324
 
325
  @staticmethod
326
  def extract_json_from_output(text: str) -> Dict:
 
506
  # Step 3: VLM Extraction (either simple or with Chain of Thought reasoning)
507
  t3 = time.time()
508
  if reasoning_mode == "reason":
509
+ # Combined Chain of Thought: reasoning + extraction in single call
510
+ print("🧠 Using Chain of Thought reasoning mode (single call)")
511
 
512
+ reasoning_output, vlm_output, vlm_latency = InferenceProcessor.run_vlm_reasoning_and_extraction(image)
513
+ timing_breakdown['vlm_inference'] = round(vlm_latency, 3)
 
 
 
 
 
 
514
 
515
  # Store reasoning for debugging/transparency
516
  timing_breakdown['reasoning_output'] = reasoning_output