Liviu16 committed on
Commit
d45f115
·
verified ·
1 Parent(s): 41e8072

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +17 -4
app.py CHANGED
@@ -84,9 +84,14 @@ def process_invoice(file_info, progress=gr.Progress()):
84
  else:
85
  image = Image.open(file_info.name)
86
 
87
- # 2. Router (Identify Vendor)
88
- progress(0.3, desc="🔍 Identifying vendor (Router)...")
89
- decision_prompt = "Identify vendor: VODAFONE, DIGI, or GENERAL. Reply with one word."
 
 
 
 
 
90
  messages = [{"role": "user", "content": [{"type": "image", "image": image}, {"type": "text", "text": decision_prompt}]}]
91
 
92
  text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
@@ -96,9 +101,17 @@ def process_invoice(file_info, progress=gr.Progress()):
96
  generated_ids = model.generate(**inputs, max_new_tokens=10)
97
  raw_choice = processor.batch_decode(generated_ids[:, inputs.input_ids.shape[1]:], skip_special_tokens=True)[0].strip().upper()
98
 
 
 
 
 
 
 
 
 
99
  vendor_key = "VODAFONE" if "VODAFONE" in raw_choice else ("DIGI" if "DIGI" in raw_choice else "GENERAL")
100
 
101
- # 3. Specialist (Extract Data)
102
  progress(0.6, desc=f"🤖 Extracting {vendor_key} details...")
103
  schema_json = json.dumps(SCHEMAS[vendor_key], indent=2)
104
  extract_prompt = f"Extract details as JSON strictly following this schema: {schema_json}. Return ONLY valid JSON."
 
84
  else:
85
  image = Image.open(file_info.name)
86
 
87
+ # 2. Router & Validation (Identify Vendor or Reject)
88
+ progress(0.3, desc="🔍 Validating and Identifying Vendor...")
89
+
90
+ # Updated prompt to provide an 'INVALID' exit
91
+ decision_prompt = """Analyze this image. Is it a financial invoice or receipt?
92
+ - If NO (e.g. random photo, object, landscape): Reply 'INVALID'.
93
+ - If YES: Reply ONLY with 'VODAFONE', 'DIGI', or 'GENERAL'."""
94
+
95
  messages = [{"role": "user", "content": [{"type": "image", "image": image}, {"type": "text", "text": decision_prompt}]}]
96
 
97
  text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
 
101
  generated_ids = model.generate(**inputs, max_new_tokens=10)
102
  raw_choice = processor.batch_decode(generated_ids[:, inputs.input_ids.shape[1]:], skip_special_tokens=True)[0].strip().upper()
103
 
104
+ # VALIDATION CHECK: If model says INVALID, stop here
105
+ if "INVALID" in raw_choice:
106
+ progress(1.0, desc="❌ Invalid Document")
107
+ return image, {
108
+ "error": "Validation Failed",
109
+ "message": "The uploaded image does not appear to be an invoice. Extraction cancelled to prevent hallucinations."
110
+ }
111
+
112
  vendor_key = "VODAFONE" if "VODAFONE" in raw_choice else ("DIGI" if "DIGI" in raw_choice else "GENERAL")
113
 
114
+ # 3. Specialist (Extract Data) - Only runs for valid documents
115
  progress(0.6, desc=f"🤖 Extracting {vendor_key} details...")
116
  schema_json = json.dumps(SCHEMAS[vendor_key], indent=2)
117
  extract_prompt = f"Extract details as JSON strictly following this schema: {schema_json}. Return ONLY valid JSON."