Namra-Satva committed on
Commit
554ef3f
·
verified ·
1 Parent(s): 88d3b0d

Update model_utils.py

Browse files
Files changed (1) hide show
  1. model_utils.py +67 -62
model_utils.py CHANGED
@@ -1,62 +1,67 @@
1
- import cv2
2
- import pytesseract
3
- import re
4
- from PIL import Image
5
- from ultralytics import YOLO
6
-
7
# Location of the trained YOLO weights used for invoice field detection
MODEL_PATH = "yolov8m_invoiceOCR.pt"

# Labels the detector can emit. A predicted class id is used as an index into
# this list, so the order must not change.
class_names = [
    "Discount_Percentage", "Due_Date", "Email_Client", "Name_Client",
    "Products", "Remise", "Subtotal", "Tax",
    "Tax_Precentage", "Tel_Client", "billing address", "header",
    "invoice date", "invoice number", "shipping address", "total",
]

# The YOLOv8 model is created once, when this module is imported
model = YOLO(MODEL_PATH)
19
-
20
def initialize_data_dict():
    """Return a fresh, empty result mapping with one key per entry of ``class_names``.

    The "Products" key maps to a new empty list; every other key maps to an
    empty string.
    """
    def _blank(label):
        # Only the product table collects multiple entries
        if label == "Products":
            return []
        return ""

    return {label: _blank(label) for label in class_names}
22
-
23
# One product row: "<qty> <description> <unit price> <amount>", where both
# monetary values have exactly two decimals and may contain thousands commas.
_PRODUCT_LINE_RE = re.compile(
    r"(\d+)\s+(.*)\s+([\d,]+\.\d{2})\s+([\d,]+\.\d{2})"
)


def parse_products(raw_text):
    """Turn the OCR text of a product table into a list of row dicts.

    Args:
        raw_text: Text of the product region, one table row per line.

    Returns:
        A list with one dict per non-blank line, in input order. A line that
        matches the product pattern yields
        ``{"qty", "description", "unit_price", "amount"}`` (all values are the
        matched strings). Any other non-blank line is preserved as
        ``{"raw": <stripped line>}``.
    """
    structured = []
    for line in raw_text.split('\n'):
        # Strip BEFORE matching: the pattern is anchored at the start of the
        # string, so leading whitespace would otherwise push a valid row into
        # the "raw" fallback.
        line = line.strip()
        if not line:
            continue

        match = _PRODUCT_LINE_RE.match(line)
        if match:
            qty, desc, unit_price, amount = match.groups()
            structured.append({
                "qty": qty,
                "description": desc.strip(),
                "unit_price": unit_price,
                "amount": amount,
            })
        else:
            structured.append({"raw": line})
    return structured
39
-
40
def extract_invoice_data_from_image(image_path: str):
    """Detect invoice fields in an image and OCR the text inside each one.

    The image is loaded with OpenCV and converted to an RGB PIL image; the
    YOLO model is run on the same path; every detected box is cropped out of
    the PIL image and passed to Tesseract.

    Args:
        image_path: Path of the invoice image on disk.

    Returns:
        The mapping from ``initialize_data_dict()`` with detected fields
        filled in. "Products" accumulates the rows returned by
        ``parse_products``; any other label holds the OCR text of the most
        recently processed box with that label. Boxes whose OCR text is empty
        leave the mapping untouched.
    """
    bgr_pixels = cv2.imread(image_path)
    # OpenCV yields BGR channel order; PIL expects RGB
    pil_img = Image.fromarray(cv2.cvtColor(bgr_pixels, cv2.COLOR_BGR2RGB))

    prediction = model(image_path)[0]
    data = initialize_data_dict()

    for detection in prediction.boxes:
        x_min, y_min, x_max, y_max = [int(coord) for coord in detection.xyxy[0]]
        field = class_names[int(detection.cls[0])]

        region = pil_img.crop((x_min, y_min, x_max, y_max))
        ocr_text = pytesseract.image_to_string(region, config='--psm 6').strip()

        if ocr_text:
            if field == "Products":
                data["Products"].extend(parse_products(ocr_text))
            else:
                data[field] = ocr_text

    return data
 
 
 
 
 
 
1
+ import cv2
2
+ import pytesseract
3
+ import re
4
+ from PIL import Image
5
+ from ultralytics import YOLO
6
+
7
# Weights file of the trained YOLO invoice-field detector
MODEL_PATH = "yolov8m_invoiceOCR.pt"

# Field labels, looked up by predicted class id -- list position is significant
class_names = [
    "Discount_Percentage",
    "Due_Date",
    "Email_Client",
    "Name_Client",
    "Products",
    "Remise",
    "Subtotal",
    "Tax",
    "Tax_Precentage",
    "Tel_Client",
    "billing address",
    "header",
    "invoice date",
    "invoice number",
    "shipping address",
    "total",
]

# Build the YOLOv8 model a single time, at import
model = YOLO(MODEL_PATH)
19
+
20
def initialize_data_dict():
    """Build an empty result record keyed by every label in ``class_names``.

    "Products" starts as an empty list; all other labels start as "".
    """
    record = {}
    for label in class_names:
        if label == "Products":
            record[label] = []
        else:
            record[label] = ""
    return record
22
+
23
# One product row: "<qty> <description> <unit price> <amount>", where both
# monetary values have exactly two decimals and may contain thousands commas.
_PRODUCT_LINE_RE = re.compile(
    r"(\d+)\s+(.*)\s+([\d,]+\.\d{2})\s+([\d,]+\.\d{2})"
)


def parse_products(raw_text):
    """Turn the OCR text of a product table into a list of uniform row dicts.

    Args:
        raw_text: Text of the product region, one table row per line.

    Returns:
        A list with one dict per non-blank line, in input order. Every dict
        has the keys ``qty``, ``description``, ``unit_price`` and ``amount``.
        For a line matching the product pattern the values are the matched
        strings. Any other non-blank line is kept as the description, with
        ``qty``, ``unit_price`` and ``amount`` set to the integer 0.
    """
    structured = []
    for line in raw_text.split('\n'):
        # Strip BEFORE matching: the pattern is anchored at the start of the
        # string, so leading whitespace would otherwise turn a valid row into
        # a zero-valued placeholder.
        line = line.strip()
        if not line:
            continue

        match = _PRODUCT_LINE_RE.match(line)
        if match:
            qty, desc, unit_price, amount = match.groups()
            structured.append({
                "qty": qty,
                "description": desc.strip(),
                "unit_price": unit_price,
                "amount": amount,
            })
        else:
            # Unparseable row: keep the text so nothing from the OCR is lost
            structured.append({
                "qty": 0,
                "description": line,
                "unit_price": 0,
                "amount": 0,
            })
    return structured
44
+
45
def extract_invoice_data_from_image(image_path: str):
    """Run field detection and OCR on one invoice image.

    Steps: read the image with OpenCV and wrap it as an RGB PIL image, run the
    YOLO model on the same path, then crop each detected box and OCR it with
    Tesseract.

    Args:
        image_path: Path of the invoice image on disk.

    Returns:
        The record from ``initialize_data_dict()`` updated with the detected
        fields. "Products" is extended with the rows from ``parse_products``;
        for every other label the OCR text of a later box replaces that of an
        earlier one. A box that yields no text changes nothing.
    """
    # cv2.imread returns BGR data, so swap channels before building the PIL image
    page = Image.fromarray(
        cv2.cvtColor(cv2.imread(image_path), cv2.COLOR_BGR2RGB)
    )

    detections = model(image_path)[0]
    data = initialize_data_dict()

    for det in detections.boxes:
        left, top, right, bottom = (int(v) for v in det.xyxy[0])
        label = class_names[int(det.cls[0])]

        text = pytesseract.image_to_string(
            page.crop((left, top, right, bottom)), config='--psm 6'
        ).strip()
        if not text:
            continue

        if label == "Products":
            data["Products"].extend(parse_products(text))
        else:
            data[label] = text

    return data