vachaspathi committed on
Commit
c835cd1
·
verified ·
1 Parent(s): 986c126

Update ai_engine.py

Browse files
Files changed (1) hide show
  1. ai_engine.py +121 -53
ai_engine.py CHANGED
@@ -2,106 +2,174 @@ import torch
2
  from transformers import AutoModelForCausalLM, AutoTokenizer
3
  import pytesseract
4
  from pdf2image import convert_from_path
5
- from PIL import Image
6
  import os
7
  import json
8
  import re
9
  import config
10
 
11
  # Load Model
12
# Load the tokenizer and model once, at import time. On any failure both
# globals are set to None so callers can test `if not model` without
# hitting a NameError on `tokenizer`.
print(">>> Loading AI Model...")
try:
    tokenizer = AutoTokenizer.from_pretrained(config.MODEL_ID)
    model = AutoModelForCausalLM.from_pretrained(config.MODEL_ID, device_map="cpu", torch_dtype=torch.float32, low_cpu_mem_usage=True)
except Exception as exc:  # import-time boundary: report the reason and degrade
    tokenizer = None
    model = None
    print(f">>> Model failed to load: {exc}")
 
18
 
19
def get_metadata(file_obj):
    """Return basic file metadata for the path ``file_obj``.

    Args:
        file_obj: Filesystem path of the uploaded document.

    Returns:
        dict with ``filename`` (base name), ``extension`` (lower-case, without
        the dot; empty when the name has no extension) and ``size_kb``.
        A placeholder dict is returned when the path is invalid or unreadable.
    """
    try:
        name = os.path.basename(file_obj)
        size = os.path.getsize(file_obj)
    except (OSError, TypeError, ValueError):
        return {"filename": "unknown", "extension": "", "size_kb": 0}

    # splitext yields "" for names without a dot, whereas split('.')[-1]
    # would report the whole filename as the extension.
    ext = os.path.splitext(name)[1].lstrip('.').lower()
    return {"filename": name, "extension": ext, "size_kb": size / 1024}
 
 
 
 
 
 
 
 
 
 
 
 
 
27
 
28
def perform_ocr(file_obj):
    """Run Tesseract OCR on the first page of a PDF or on an image file.

    Args:
        file_obj: Filesystem path of the document, or None.

    Returns:
        Tuple ``(text, image, meta)``: the OCR text, the image that was
        read, and the metadata dict from ``get_metadata``. On any failure
        (or when ``file_obj`` is None) returns ``("", None, {})``.
    """
    if file_obj is None:
        return "", None, {}
    try:
        meta = get_metadata(file_obj)
        if meta["filename"].lower().endswith(".pdf"):
            # Only the first page is rendered and OCR'd.
            image = convert_from_path(file_obj, first_page=1, last_page=1)[0]
        else:
            # convert() returns an independent copy, so the file handle can
            # be released as soon as the block exits.
            with Image.open(file_obj) as src:
                image = src.convert("RGB")
        text = pytesseract.image_to_string(image)
        return text, image, meta
    except Exception as exc:  # best-effort: report instead of failing silently
        print(f"OCR Error: {exc}")
        return "", None, {}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
39
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
40
def repair_json(json_str):
    """Extract a JSON object from messy model output.

    Two strategies are tried in order: parse the whole string, then parse
    the span between the first '{' and the last '}'.

    Args:
        json_str: Raw text that may contain a JSON object.

    Returns:
        The parsed dict, or ``{}`` when no JSON *object* can be recovered.
        Non-object JSON (lists, numbers, strings) is rejected because
        callers treat the result as a dict.
    """
    if not json_str or not isinstance(json_str, str):
        return {}

    # Strategy 1: direct load.
    try:
        parsed = json.loads(json_str)
    except ValueError:
        parsed = None
    if isinstance(parsed, dict):
        return parsed

    # Strategy 2: extract between the first '{' and the last '}'.
    start = json_str.find('{')
    end = json_str.rfind('}') + 1
    if start != -1 and end > start:
        try:
            parsed = json.loads(json_str[start:end])
        except ValueError:
            return {}
        if isinstance(parsed, dict):
            return parsed

    return {}
58
 
59
def fallback_classifier(text, filename):
    """Guess the document type from keywords in the OCR text and filename.

    Args:
        text: OCR text of the document (None is treated as empty).
        filename: Name of the source file (None is treated as empty).

    Returns:
        One of "invoice", "estimate", "bill" or "expense". Keywords are
        checked in that priority order; "invoice" is the default when
        nothing matches.
    """
    # Either argument may be missing (e.g. metadata was {}), so coerce
    # None to "" instead of failing on string concatenation.
    combined = f"{text or ''} {filename or ''}".lower()

    keyword_to_type = (
        ("invoice", "invoice"),
        ("estimate", "estimate"),
        ("bill", "bill"),
        ("receipt", "expense"),
    )
    for keyword, doc_type in keyword_to_type:
        if keyword in combined:
            return doc_type
    return "invoice"  # Default to invoice
66
-
67
def extract_intelligent_json(text, metadata):
    """Ask the language model to turn OCR text into structured JSON.

    Args:
        text: OCR text of the document; only the first 1500 characters
            are sent to the model.
        metadata: Dict that may carry a ``filename`` key.

    Returns:
        ``{}`` when no model is loaded. Otherwise a dict with ``doc_type``
        and ``data``; when the model output cannot be parsed, the type
        comes from ``fallback_classifier`` and the vendor is "Unknown".
    """
    if not model:
        return {}

    prompt = f"""<|im_start|>system
Extract JSON data. Valid doc_types: ["invoice", "bill", "estimate", "expense"].

OUTPUT FORMAT:
{{
"doc_type": "invoice",
"data": {{
"vendor_name": "Name or 'Unknown'",
"date": "YYYY-MM-DD",
"reference_number": "REF-123",
"total": 0.00,
"line_items": [ {{"name": "Item", "description": "Desc", "rate": 0, "quantity": 1}} ]
}}
}}
<|im_end|>
<|im_start|>user
FILE: {metadata.get('filename')}
CONTENT:
{text[:1500]}
<|im_end|>
<|im_start|>assistant
```json
"""

    inputs = tokenizer(prompt, return_tensors="pt")
    prompt_len = inputs["input_ids"].shape[-1]
    out = model.generate(**inputs, max_new_tokens=400, temperature=0.1)

    # generate() returns prompt + completion for causal LMs. Decode only the
    # completion, otherwise the example JSON in the prompt is parsed too.
    raw_output = tokenizer.decode(out[0][prompt_len:], skip_special_tokens=True)

    data = repair_json(raw_output)

    # If repair failed or the result lacks a type, fall back to heuristics.
    if not data or "doc_type" not in data:
        # filename can be absent (metadata == {}); never pass None on.
        doc_type = fallback_classifier(text, metadata.get('filename') or "")
        data = {"doc_type": doc_type, "data": {"vendor_name": "Unknown"}}

    return data
 
2
  from transformers import AutoModelForCausalLM, AutoTokenizer
3
  import pytesseract
4
  from pdf2image import convert_from_path
5
+ from PIL import Image, ImageEnhance, ImageFilter
6
  import os
7
  import json
8
  import re
9
  import config
10
 
11
  # Load Model
12
# Load the tokenizer and model once, at import time. On any failure both
# globals are set to None so callers can test `if not model` without
# hitting a NameError on `tokenizer`.
print(f">>> Loading AI Model: {config.MODEL_ID}...")
try:
    tokenizer = AutoTokenizer.from_pretrained(config.MODEL_ID)
    model = AutoModelForCausalLM.from_pretrained(config.MODEL_ID, device_map="cpu", torch_dtype=torch.float32, low_cpu_mem_usage=True)
except Exception as exc:  # import-time boundary: report the reason and degrade
    tokenizer = None
    model = None
    print(f"❌ Model Failed to Load: {exc}")
19
 
20
+ # =====================================================
21
+ # 1. ADVANCED OCR PIPELINE
22
+ # =====================================================
23
def preprocess_image(image, contrast=2.0):
    """Clean up an image so OCR has an easier time reading it.

    Steps, in the order applied:
    1. Convert to grayscale
    2. Increase contrast
    3. Sharpen

    Args:
        image: A PIL image.
        contrast: Factor passed to the contrast enhancer. The default of
            2.0 matches the previously hard-coded value.

    Returns:
        A new, processed grayscale image.
    """
    # Convert to gray
    gray = image.convert('L')

    # Increase contrast
    boosted = ImageEnhance.Contrast(gray).enhance(contrast)

    # Sharpen (helps with blurry fonts)
    return boosted.filter(ImageFilter.SHARPEN)
41
 
42
def perform_ocr(file_obj):
    """Run Tesseract OCR on the first page of a PDF or on an image file.

    Args:
        file_obj: Filesystem path of the document, or None.

    Returns:
        Tuple ``(text, original_img, meta)``: the OCR text, the
        un-preprocessed RGB/page image, and a metadata dict with
        ``filename``, ``extension`` and ``size_kb``. On any failure (or
        when ``file_obj`` is None) returns ``("", None, {})``.
    """
    if file_obj is None:
        return "", None, {}
    try:
        filename = os.path.basename(file_obj)

        if filename.lower().endswith(".pdf"):
            # Render only page 1, at 300 DPI so small print stays legible.
            pages = convert_from_path(file_obj, first_page=1, last_page=1, dpi=300)
            original_img = pages[0]
        else:
            # convert() returns an independent copy, so the file handle can
            # be released as soon as the block exits.
            with Image.open(file_obj) as src:
                original_img = src.convert("RGB")

        # Preprocess for Tesseract; the caller still gets the original image.
        processed_img = preprocess_image(original_img)
        text = pytesseract.image_to_string(processed_img)

        meta = {
            "filename": filename,
            # Kept for callers that relied on the earlier metadata shape.
            "extension": os.path.splitext(filename)[1].lstrip(".").lower(),
            "size_kb": os.path.getsize(file_obj) / 1024,
        }

        return text, original_img, meta
    except Exception as e:  # best-effort boundary: report and degrade
        print(f"OCR Error: {e}")
        return "", None, {}
71
+
72
+ # =====================================================
73
+ # 2. REGEX FALLBACKS (The "Generic Name" Fix)
74
+ # =====================================================
75
def regex_extract_vendor(text):
    """Heuristically find the vendor (sender) name in OCR text.

    Used when the model did not supply a vendor name.

    Strategy:
    1. A "From" / "Vendor" label, either with an inline value
       ("Vendor: Acme") or followed by the name on the next line.
    2. Otherwise the top-most line, skipping a leading "Invoice" header.

    "Bill To" / "Invoice To" labels are deliberately not used: the line
    after them names the customer, not the vendor.

    Args:
        text: OCR text (None or empty yields "Unknown").

    Returns:
        The best-guess vendor name, or "Unknown".
    """
    if not text:
        return "Unknown"

    # Drop blank lines and very short fragments (OCR noise).
    lines = [s for s in (raw.strip() for raw in text.split('\n')) if len(s) > 3]

    # 1. Explicit sender labels
    for i, line in enumerate(lines):
        inline = re.match(r'^(?:from|vendor)\s*:\s*(\S.*)$', line, re.IGNORECASE)
        if inline:
            return inline.group(1).strip()
        if re.search(r'^(from|vendor):?$', line.lower()) and i + 1 < len(lines):
            return lines[i + 1]

    # 2. Top-most text (heuristic: the first or second line is usually the
    #    company name); ignore a common "Invoice" header.
    if lines:
        if "invoice" not in lines[0].lower():
            return lines[0]
        if len(lines) > 1:
            return lines[1]

    return "Unknown"
97
+
98
def regex_extract_total(text):
    """Find a monetary total such as "Total $1,234.56" in OCR text.

    The first "total", "amount" or "balance" keyword that is followed on
    the same line by a number with two decimals wins. "Subtotal",
    "sub total" and "sub-total" are skipped so a pre-tax figure is not
    reported as the total.

    Args:
        text: OCR text (None or empty yields 0.0).

    Returns:
        The amount as a float, or 0.0 when nothing matches.
    """
    if not text:
        return 0.0

    match = re.search(
        r'(?<!sub)(?<!sub[ \-])(?:total|amount|balance).*?(\d[\d,]*\.\d{2})',
        text.lower(),
    )
    if not match:
        return 0.0
    try:
        return float(match.group(1).replace(',', ''))
    except ValueError:
        return 0.0
105
+
106
+ # =====================================================
107
+ # 3. AI EXTRACTION
108
+ # =====================================================
109
def repair_json(json_str):
    """Recover a JSON object from messy model output.

    First tries the span between the first '{' and the last '}'. If that
    does not parse (several objects, stray braces, trailing chatter), the
    text is scanned for individually decodable objects and the last one
    is returned, since a model's answer follows any echoed example.

    Args:
        json_str: Raw text that may contain JSON.

    Returns:
        The parsed dict, or ``{}`` when nothing can be recovered.
    """
    if not json_str or not isinstance(json_str, str):
        return {}

    start = json_str.find('{')
    if start == -1:
        return {}

    # Fast path: find the first { and the last }.
    end = json_str.rfind('}') + 1
    if end > start:
        try:
            parsed = json.loads(json_str[start:end])
        except ValueError:
            parsed = None
        if isinstance(parsed, dict):
            return parsed

    # Slow path: try to decode an object at every '{' and keep the last hit.
    decoder = json.JSONDecoder()
    found = {}
    pos = start
    while pos != -1:
        try:
            obj, stop = decoder.raw_decode(json_str, pos)
        except ValueError:
            pos = json_str.find('{', pos + 1)
            continue
        if isinstance(obj, dict):
            found = obj
        # Resume after the decoded object so nested objects are not revisited.
        pos = json_str.find('{', stop)
    return found
119
 
 
 
 
 
 
 
 
 
120
def extract_intelligent_json(text, metadata):
    """Ask the language model to turn OCR text into structured JSON.

    Args:
        text: OCR text of the document; only the first 2000 characters
            are sent to the model.
        metadata: File metadata dict (unused here; kept for interface
            compatibility).

    Returns:
        ``{}`` when no model is loaded. Otherwise a dict with ``doc_type``
        and a ``data`` dict; a missing vendor name or total is filled in
        by ``regex_extract_vendor`` / ``regex_extract_total``.
    """
    if not model:
        return {}

    # Stronger Prompt
    prompt = f"""<|im_start|>system
You are a financial data extractor.
TASK: Convert OCR text into JSON.

MANDATORY RULES:
1. Extract the VENDOR_NAME (Who sent the invoice?)
2. Extract the DOCUMENT_TYPE: ["invoice", "bill", "expense", "estimate"]
3. Extract LINE_ITEMS.

JSON FORMAT:
{{
"doc_type": "invoice",
"data": {{
"vendor_name": "Acme Corp",
"date": "2024-01-01",
"reference_number": "INV-001",
"total": 100.00,
"line_items": [ {{"name": "Service", "description": "...", "rate": 100, "quantity": 1}} ]
}}
}}
<|im_end|>
<|im_start|>user
DOCUMENT TEXT:
{text[:2000]}
<|im_end|>
<|im_start|>assistant
```json
"""

    inputs = tokenizer(prompt, return_tensors="pt")
    prompt_len = inputs["input_ids"].shape[-1]
    out = model.generate(**inputs, max_new_tokens=500, temperature=0.1)

    # generate() returns prompt + completion for causal LMs. Decode only the
    # completion: the prompt holds an example JSON object ("Acme Corp") that
    # would otherwise break parsing or be mistaken for the extracted data.
    raw_output = tokenizer.decode(out[0][prompt_len:], skip_special_tokens=True)

    data = repair_json(raw_output)

    # --- FALLBACK LAYER ---
    # If AI returned empty/garbage data, overlay with Regex
    if not data or "data" not in data:
        data = {"doc_type": "invoice", "data": {}}

    inner = data.get("data")
    if not isinstance(inner, dict):
        # The model may emit a list/string/null under "data"; start clean
        # rather than crash on .get() below.
        inner = {}

    # Fix Name
    if not inner.get("vendor_name") or inner["vendor_name"] == "Unknown":
        inner["vendor_name"] = regex_extract_vendor(text)

    # Fix Total
    if not inner.get("total"):
        inner["total"] = regex_extract_total(text)

    data["data"] = inner
    return data