ivanhoang commited on
Commit
e68ec86
·
verified ·
1 Parent(s): 62477e6

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +42 -27
app.py CHANGED
@@ -5,35 +5,33 @@ import json
5
  import io
6
  from PIL import Image
7
  import torch
8
- from transformers import AutoProcessor, AutoModelForCausalLM, BitsAndBytesConfig, pipeline
9
 
10
- # --- CẤU HÌNH VÀ TẢI MÔ HÌNH (PHIÊN BẢN NHẸ HƠN) ---
11
- device = "cuda" if torch.cuda.is_available() else "cpu"
12
  print(f"Đang sử dụng thiết bị: {device}")
13
 
14
- # 1. Tải mô hình OCR (Sử dụng phiên bản 'base' thay vì 'large')
15
  print("Đang tải mô hình OCR (Florence-2-base)...")
16
  ocr_model_id = "microsoft/Florence-2-base"
17
  ocr_processor = AutoProcessor.from_pretrained(ocr_model_id, trust_remote_code=True)
18
- quantization_config = BitsAndBytesConfig(load_in_4bit=True)
19
  ocr_model = AutoModelForCausalLM.from_pretrained(
20
  ocr_model_id,
21
- device_map="auto",
22
- torch_dtype=torch.bfloat16,
23
- quantization_config=quantization_config,
24
  trust_remote_code=True,
25
- attn_implementation="eager" # <-- THÊM DÒNG NÀY ĐỂ SỬA LỖI
26
  )
27
  print("Tải xong mô hình OCR.")
28
 
29
- # 2. Tải mô hình LLM (Sử dụng phiên bản đã được nén sẵn)
30
- print("Đang tải mô hình LLM (Unsloth Llama 3 8B 4-bit)...")
31
- llm_model_id = "unsloth/llama-3-8b-Instruct-bnb-4bit"
32
  llm_pipeline = pipeline(
33
  "text-generation",
34
  model=llm_model_id,
35
- model_kwargs={"torch_dtype": torch.bfloat16},
36
- device_map="auto",
37
  )
38
  print("Tải xong mô hình LLM.")
39
 
@@ -44,7 +42,7 @@ def run_ocr(image: Image.Image) -> str:
44
  if image is None: return "Lỗi: Vui lòng cung cấp hình ảnh."
45
  prompt = "<OCR>"
46
  inputs = ocr_processor(text=prompt, images=image, return_tensors="pt").to(device)
47
- generated_ids = ocr_model.generate(
48
  input_ids=inputs["input_ids"],
49
  pixel_values=inputs["pixel_values"],
50
  max_new_tokens=2048,
@@ -74,18 +72,35 @@ def create_excel_file(order_data: dict):
74
  return (filename, output.getvalue())
75
 
76
  def process_image_and_extract(image):
77
- if image is None: return "Vui lòng dán ảnh vào.", None, None
78
- extracted_text = run_ocr(image)
79
- if not extracted_text.strip(): return "Không đọc được chữ từ hình ảnh.", None, None
80
- order_data = extract_order_from_text(extracted_text)
81
- if "error" in order_data: return extracted_text, f"Lỗi từ AI: {order_data['error']}\nPhản hồi gốc: {order_data['raw_response']}", None
82
- excel_info = create_excel_file(order_data)
83
- df_display = pd.DataFrame(order_data.get('danh_sach_hang', []))
84
- if excel_info:
85
- filename, filebytes = excel_info
86
- with open(filename, "wb") as f: f.write(filebytes)
87
- return extracted_text, df_display, filename
88
- else: return extracted_text, df_display, None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
89
 
90
  # --- XÂY DỰNG GIAO DIỆN GRADIO (GIỮ NGUYÊN) ---
91
 
 
5
  import io
6
  from PIL import Image
7
  import torch
8
+ from transformers import AutoProcessor, AutoModelForCausalLM, pipeline
9
 
10
+ # --- CẤU HÌNH VÀ TẢI MÔ HÌNH (PHIÊN BẢN ỔN ĐỊNH CHO CPU) ---
11
+ device = "cpu" # Ép chạy trên CPU để đảm bảo
12
  print(f"Đang sử dụng thiết bị: {device}")
13
 
14
+ # 1. Tải mô hình OCR (Không nén)
15
  print("Đang tải mô hình OCR (Florence-2-base)...")
16
  ocr_model_id = "microsoft/Florence-2-base"
17
  ocr_processor = AutoProcessor.from_pretrained(ocr_model_id, trust_remote_code=True)
 
18
  ocr_model = AutoModelForCausalLM.from_pretrained(
19
  ocr_model_id,
20
+ device_map=device, # Chạy trên CPU
21
+ torch_dtype=torch.float32, # Dùng float32 cho CPU
 
22
  trust_remote_code=True,
23
+ attn_implementation="eager"
24
  )
25
  print("Tải xong mô hình OCR.")
26
 
27
+ # 2. Tải mô hình LLM (Không nén)
28
+ print("Đang tải mô hình LLM (Meta Llama 3 8B)...")
29
+ llm_model_id = "meta-llama/Meta-Llama-3-8B-Instruct" # Dùng phiên bản gốc
30
  llm_pipeline = pipeline(
31
  "text-generation",
32
  model=llm_model_id,
33
+ model_kwargs={"torch_dtype": torch.float32}, # Dùng float32 cho CPU
34
+ device=0 if device == "cuda" else -1, # -1 để pipeline dùng CPU
35
  )
36
  print("Tải xong mô hình LLM.")
37
 
 
42
  if image is None: return "Lỗi: Vui lòng cung cấp hình ảnh."
43
  prompt = "<OCR>"
44
  inputs = ocr_processor(text=prompt, images=image, return_tensors="pt").to(device)
45
+ generated_ids = oocr_model.generate(
46
  input_ids=inputs["input_ids"],
47
  pixel_values=inputs["pixel_values"],
48
  max_new_tokens=2048,
 
72
  return (filename, output.getvalue())
73
 
74
  def process_image_and_extract(image):
75
+ try:
76
+ if image is None: return "Vui lòng dán ảnh vào.", None, None
77
+
78
+ print("Bắt đầu OCR...")
79
+ extracted_text = run_ocr(image)
80
+ if not extracted_text.strip(): return "Không đọc được chữ từ hình ảnh.", None, None
81
+ print(f"Văn bản OCR: {extracted_text}")
82
+
83
+ print("Bắt đầu trích xuất LLM...")
84
+ order_data = extract_order_from_text(extracted_text)
85
+ if "error" in order_data: return extracted_text, f"Lỗi từ AI: {order_data['error']}\nPhản hồi gốc: {order_data['raw_response']}", None
86
+ print(f"Dữ liệu trích xuất: {order_data}")
87
+
88
+ excel_info = create_excel_file(order_data)
89
+ df_display = pd.DataFrame(order_data.get('danh_sach_hang', []))
90
+
91
+ if excel_info:
92
+ filename, filebytes = excel_info
93
+ with open(filename, "wb") as f: f.write(filebytes)
94
+ return extracted_text, df_display, filename
95
+ else: return extracted_text, df_display, None
96
+ except Exception as e:
97
+ # Bắt lỗi và hiển thị trong giao diện để dễ gỡ lỗi
98
+ import traceback
99
+ error_str = str(e)
100
+ traceback_str = traceback.format_exc()
101
+ print(traceback_str)
102
+ return f"Lỗi nghiêm trọng: {error_str}", None, None
103
+
104
 
105
  # --- XÂY DỰNG GIAO DIỆN GRADIO (GIỮ NGUYÊN) ---
106