vachaspathi committed on
Commit
86f6949
·
verified ·
1 Parent(s): 9484c37

Update ai_engine.py

Browse files
Files changed (1) hide show
  1. ai_engine.py +41 -58
ai_engine.py CHANGED
@@ -17,7 +17,6 @@ except:
17
  model = None
18
 
19
  def get_metadata(file_obj):
20
- """Extracts file clues."""
21
  try:
22
  name = os.path.basename(file_obj)
23
  size = os.path.getsize(file_obj)
@@ -27,68 +26,65 @@ def get_metadata(file_obj):
27
  return {"filename": "unknown", "extension": "", "size_kb": 0}
28
 
29
  def perform_ocr(file_obj):
30
- if file_obj is None: return "", None
31
  try:
32
- # extract metadata before processing
33
  meta = get_metadata(file_obj)
34
-
35
  if meta["filename"].lower().endswith(".pdf"):
36
  image = convert_from_path(file_obj, first_page=1, last_page=1)[0]
37
  else:
38
  image = Image.open(file_obj).convert("RGB")
39
-
40
  text = pytesseract.image_to_string(image)
41
  return text, image, meta
42
  except: return "", None, {}
43
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
44
  def fallback_classifier(text, filename):
45
- """
46
- Rule-based classifier if AI fails.
47
- """
48
  combined = (text + " " + filename).lower()
49
-
50
- if "invoice" in combined or "inv-" in combined: return "invoice"
51
- if "estimate" in combined or "quote" in combined: return "estimate"
52
- if "credit note" in combined: return "credit_note"
53
- if "purchase order" in combined or "po-" in combined: return "purchase_order"
54
- if "bill" in combined or "payment due" in combined: return "bill"
55
  if "receipt" in combined: return "expense"
56
-
57
- return "unknown"
58
 
59
  def extract_intelligent_json(text, metadata):
60
- """
61
- Combines OCR + Metadata -> AI -> JSON
62
- """
63
  if not model: return {}
64
 
65
- # Inject Metadata into System Prompt
66
  prompt = f"""<|im_start|>system
67
- You are a Document Classifier. Use the Filename and Text to identify the document type.
68
-
69
- VALID TYPES: ["invoice", "bill", "estimate", "credit_note", "purchase_order", "expense"]
70
 
71
- RULES:
72
- 1. If filename contains 'INV', it is an 'invoice'.
73
- 2. If text mentions 'Purchase Order', it is a 'purchase_order'.
74
- 3. Extract the Vendor/Customer Name and Dates carefully.
75
-
76
- OUTPUT JSON FORMAT:
77
  {{
78
  "doc_type": "invoice",
79
- "confidence": "high",
80
  "data": {{
81
- "contact_name": "...",
82
  "date": "YYYY-MM-DD",
83
- "reference_number": "...",
84
  "total": 0.00,
85
- "line_items": [ {{"name": "...", "description": "...", "rate": 0, "quantity": 1}} ]
86
  }}
87
  }}
88
  <|im_end|>
89
  <|im_start|>user
90
- METADATA: {json.dumps(metadata)}
91
- DOCUMENT TEXT:
92
  {text[:1500]}
93
  <|im_end|>
94
  <|im_start|>assistant
@@ -98,27 +94,14 @@ def extract_intelligent_json(text, metadata):
98
  inputs = tokenizer(prompt, return_tensors="pt")
99
  out = model.generate(**inputs, max_new_tokens=400, temperature=0.1)
100
 
101
- try:
102
- # Extract JSON block using Regex (More robust than split)
103
- full_response = tokenizer.decode(out[0])
104
- json_match = re.search(r"```json\s*(\{.*?\})\s*```", full_response, re.DOTALL)
105
-
106
- if json_match:
107
- data = json.loads(json_match.group(1))
108
- else:
109
- # Fallback: Try finding the first { and last }
110
- start = full_response.find("{")
111
- end = full_response.rfind("}") + 1
112
- data = json.loads(full_response[start:end])
113
-
114
- # Double Check Classification
115
- if data.get("doc_type") == "unknown":
116
- data["doc_type"] = fallback_classifier(text, metadata.get("filename", ""))
117
-
118
- return data
119
 
120
- except Exception as e:
121
- print(f"AI Parsing Error: {e}")
122
- # Hard Fallback
123
- guessed_type = fallback_classifier(text, metadata.get("filename", ""))
124
- return {"doc_type": guessed_type, "data": {}}
 
17
  model = None
18
 
19
  def get_metadata(file_obj):
 
20
  try:
21
  name = os.path.basename(file_obj)
22
  size = os.path.getsize(file_obj)
 
26
  return {"filename": "unknown", "extension": "", "size_kb": 0}
27
 
def perform_ocr(file_obj):
    """Run OCR on a file and return its text, the image used, and metadata.

    Args:
        file_obj: Path to the input file, or ``None``. It is handed to
            ``get_metadata``, ``convert_from_path`` and ``Image.open``.

    Returns:
        A ``(text, image, meta)`` tuple. On ``None`` input or on any
        processing failure the result is ``("", None, {})`` so callers can
        always unpack three values.
    """
    if file_obj is None:
        return "", None, {}

    try:
        meta = get_metadata(file_obj)

        if meta["filename"].lower().endswith(".pdf"):
            # Only the first page of a PDF is rasterized and OCR'd.
            image = convert_from_path(file_obj, first_page=1, last_page=1)[0]
        else:
            image = Image.open(file_obj).convert("RGB")

        text = pytesseract.image_to_string(image)
        return text, image, meta
    except Exception:
        # Best-effort by design: any OCR/conversion failure yields an empty
        # result. ``Exception`` (rather than a bare ``except``) lets
        # KeyboardInterrupt and SystemExit propagate.
        return "", None, {}
39
 
def repair_json(json_str):
    """Extract the largest valid JSON object from messy text.

    The input is typically raw model output that may wrap the JSON payload
    in extra tokens, or contain several brace-delimited fragments.

    Args:
        json_str: Text that may contain one or more JSON objects. ``None``
            and empty strings are accepted.

    Returns:
        The parsed ``dict``. When the text contains several valid objects,
        the one spanning the most characters is returned. Returns ``{}``
        when no JSON object can be recovered; a non-object JSON value
        (list, number, ...) is never returned.
    """
    if not json_str:
        return {}

    # Strategy 1: the whole string is already a JSON document.
    try:
        parsed = json.loads(json_str)
    except (TypeError, ValueError, RecursionError):
        parsed = None
    if isinstance(parsed, dict):
        return parsed

    if not isinstance(json_str, str):
        return {}

    # Strategy 2: try to decode an object at every '{' and keep the longest
    # one. Unlike slicing from the first '{' to the last '}', this still
    # works when the text holds several objects or stray braces.
    decoder = json.JSONDecoder()
    best, best_len = {}, 0
    pos = json_str.find("{")
    while pos != -1:
        try:
            candidate, end = decoder.raw_decode(json_str, pos)
        except (ValueError, RecursionError):
            # Not a valid object here; retry from the next opening brace.
            pos = json_str.find("{", pos + 1)
            continue
        if isinstance(candidate, dict) and end - pos > best_len:
            best, best_len = candidate, end - pos
        # Objects nested inside this one are necessarily shorter: skip them.
        pos = json_str.find("{", end)

    return best
58
+
def fallback_classifier(text, filename):
    """Guess a document type from keywords in the text and filename.

    Matching is a case-insensitive substring search; the first keyword
    found, in the order invoice, estimate, bill, receipt, decides the type.

    Args:
        text: Document text; ``None`` is treated as empty.
        filename: Name of the source file; ``None`` is treated as empty.

    Returns:
        One of ``"invoice"``, ``"estimate"``, ``"bill"`` or ``"expense"``.
        Defaults to ``"invoice"`` when no keyword matches.
    """
    # Either argument may be missing (e.g. a metadata dict without a
    # "filename" key), so normalise None to "" before concatenating.
    combined = f"{text or ''} {filename or ''}".lower()

    # Ordered (keyword, doc_type) pairs: earlier entries take precedence.
    keyword_types = (
        ("invoice", "invoice"),
        ("estimate", "estimate"),
        ("bill", "bill"),
        ("receipt", "expense"),
    )
    for keyword, doc_type in keyword_types:
        if keyword in combined:
            return doc_type

    return "invoice"  # Default to invoice
 
66
 
67
  def extract_intelligent_json(text, metadata):
 
 
 
68
  if not model: return {}
69
 
 
70
  prompt = f"""<|im_start|>system
71
+ Extract JSON data. Valid doc_types: ["invoice", "bill", "estimate", "expense"].
 
 
72
 
73
+ OUTPUT FORMAT:
 
 
 
 
 
74
  {{
75
  "doc_type": "invoice",
 
76
  "data": {{
77
+ "vendor_name": "Name or 'Unknown'",
78
  "date": "YYYY-MM-DD",
79
+ "reference_number": "REF-123",
80
  "total": 0.00,
81
+ "line_items": [ {{"name": "Item", "description": "Desc", "rate": 0, "quantity": 1}} ]
82
  }}
83
  }}
84
  <|im_end|>
85
  <|im_start|>user
86
+ FILE: {metadata.get('filename')}
87
+ CONTENT:
88
  {text[:1500]}
89
  <|im_end|>
90
  <|im_start|>assistant
 
94
  inputs = tokenizer(prompt, return_tensors="pt")
95
  out = model.generate(**inputs, max_new_tokens=400, temperature=0.1)
96
 
97
+ raw_output = tokenizer.decode(out[0])
98
+
99
+ # Use the new Repair Function
100
+ data = repair_json(raw_output)
101
+
102
+ # If repair failed or empty, use heuristics
103
+ if not data or "doc_type" not in data:
104
+ doc_type = fallback_classifier(text, metadata.get('filename'))
105
+ data = {"doc_type": doc_type, "data": {"vendor_name": "Unknown"}}
 
 
 
 
 
 
 
 
 
106
 
107
+ return data