vachaspathi committed on
Commit
dcb9f42
·
verified ·
1 Parent(s): c40ad39

Update ai_engine.py

Browse files
Files changed (1) hide show
  1. ai_engine.py +40 -12
ai_engine.py CHANGED
@@ -1,4 +1,3 @@
1
- # ai_engine.py
2
  import torch
3
  from transformers import AutoModelForCausalLM, AutoTokenizer
4
  import pytesseract
@@ -8,14 +7,13 @@ import os
8
  import json
9
  import config
10
 
11
- # Load Model Once
12
  print(">>> Loading AI Model...")
13
  try:
14
  tokenizer = AutoTokenizer.from_pretrained(config.MODEL_ID)
15
  model = AutoModelForCausalLM.from_pretrained(config.MODEL_ID, device_map="cpu", torch_dtype=torch.float32, low_cpu_mem_usage=True)
16
  except:
17
  model = None
18
- print("❌ Model Failed to Load")
19
 
20
  def perform_ocr(file_obj):
21
  if file_obj is None: return "", None
@@ -26,17 +24,47 @@ def perform_ocr(file_obj):
26
  else:
27
  image = Image.open(file_obj).convert("RGB")
28
  return pytesseract.image_to_string(image), image
29
- except:
30
- return "", None
31
 
32
- def extract_json(text):
 
 
 
33
  if not model: return {}
34
- prompt = f"<|im_start|>user\nExtract JSON: vendor_name, invoice_date, total, item_desc\nText:\n{text[:1000]}<|im_end|>\n<|im_start|>assistant\n```json"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
35
  inputs = tokenizer(prompt, return_tensors="pt")
36
- out = model.generate(**inputs, max_new_tokens=200)
 
37
  try:
38
  json_str = tokenizer.decode(out[0]).split("```json")[1].split("```")[0].strip()
39
- data = json.loads(json_str)
40
- return data[0] if isinstance(data, list) else data
41
- except:
42
- return {}
 
 
 
1
  import torch
2
  from transformers import AutoModelForCausalLM, AutoTokenizer
3
  import pytesseract
 
7
  import json
8
  import config
9
 
10
# Load the tokenizer and model once, at import time, so every extraction call
# reuses the same instances. On failure both names are set to None; callers
# check `model` before running inference.
print(">>> Loading AI Model...")
try:
    tokenizer = AutoTokenizer.from_pretrained(config.MODEL_ID)
    model = AutoModelForCausalLM.from_pretrained(
        config.MODEL_ID,
        device_map="cpu",
        torch_dtype=torch.float32,
        low_cpu_mem_usage=True,
    )
except Exception as exc:
    # `except Exception` (not a bare `except`) lets KeyboardInterrupt and
    # SystemExit propagate. `tokenizer` is reset too, so it is always defined
    # even when the failure happened before it was assigned.
    tokenizer = None
    model = None
    print(f">>> Model failed to load: {exc}")
 
17
 
18
  def perform_ocr(file_obj):
19
  if file_obj is None: return "", None
 
24
  else:
25
  image = Image.open(file_obj).convert("RGB")
26
  return pytesseract.image_to_string(image), image
27
+ except: return "", None
 
28
 
29
def extract_intelligent_json(text):
    """Classify a document and extract its fields with the loaded model.

    The first 1500 characters of *text* are embedded in a ChatML-style prompt
    that asks the model for a JSON object of the form
    ``{"doc_type": ..., "data": {...}}``.

    Args:
        text: Raw document text (for example OCR output). ``None`` is treated
            as an empty string.

    Returns:
        The parsed JSON object as a dict. Returns ``{}`` when no model is
        loaded, and ``{"doc_type": "unknown", "data": {}}`` when the model's
        reply cannot be parsed into a JSON object.
    """
    if not model:
        return {}

    snippet = (text or "")[:1500]

    # Prompt instructing the model to classify and then emit JSON only. It
    # ends with an opening ```json fence so the model continues with the
    # JSON body directly.
    prompt = f"""<|im_start|>system
Analyze the document text.
1. CLASSIFY the type as one of: ["invoice", "estimate", "credit_note", "expense", "contact", "purchase_order"].
2. EXTRACT data based on the type.

OUTPUT FORMAT (JSON ONLY):
{{
"doc_type": "invoice",
"data": {{
"vendor_name": "...",
"date": "YYYY-MM-DD",
"reference_number": "...",
"total": 0.00,
"line_items": [ {{"name": "...", "rate": 0, "quantity": 1}} ]
}}
}}
<|im_end|>
<|im_start|>user
DOCUMENT TEXT:
{snippet}
<|im_end|>
<|im_start|>assistant
```json
"""

    inputs = tokenizer(prompt, return_tensors="pt")
    out = model.generate(**inputs, max_new_tokens=350, temperature=0.1)

    # generate() returns the prompt tokens followed by the continuation.
    # Decode only the continuation so fences inside the prompt or the
    # document text cannot be mistaken for the model's answer, and drop
    # special tokens so an end-of-turn marker never lands inside the JSON.
    prompt_len = inputs["input_ids"].shape[1]
    reply = tokenizer.decode(out[0][prompt_len:], skip_special_tokens=True).strip()

    # Tolerate a model that repeats the opening fence, then cut at the
    # closing fence if one was produced.
    if reply.startswith("```json"):
        reply = reply[len("```json"):]
    json_str = reply.split("```")[0].strip()

    try:
        parsed = json.loads(json_str)
    except ValueError as e:  # json.JSONDecodeError is a ValueError subclass
        print(f"AI Error: {e}")
        # Fallback default
        return {"doc_type": "unknown", "data": {}}

    if not isinstance(parsed, dict):
        # The contract is a JSON object; a list or scalar is a malformed reply.
        print(f"AI Error: expected a JSON object, got {type(parsed).__name__}")
        return {"doc_type": "unknown", "data": {}}

    return parsed