ivanhoang commited on
Commit
3d59ca3
·
verified ·
1 Parent(s): d05e8fc

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +53 -36
app.py CHANGED
@@ -5,44 +5,49 @@ import json
5
  import io
6
  from PIL import Image
7
  import torch
8
- from transformers import AutoProcessor, AutoModelForCausalLM, pipeline
 
9
 
10
- # --- CẤU HÌNH VÀ TẢI MÔ HÌNH (PHIÊN BẢN ỔN ĐỊNH CHO CPU) ---
11
- device = "cpu" # Ép chạy trên CPU để đảm bảo
12
  print(f"Đang sử dụng thiết bị: {device}")
13
 
14
- # 1. Tải mô hình OCR (Không nén)
15
  print("Đang tải mô hình OCR (Florence-2-base)...")
16
  ocr_model_id = "microsoft/Florence-2-base"
17
  ocr_processor = AutoProcessor.from_pretrained(ocr_model_id, trust_remote_code=True)
18
  ocr_model = AutoModelForCausalLM.from_pretrained(
19
  ocr_model_id,
20
- device_map=device, # Chạy trên CPU
21
- torch_dtype=torch.float32, # Dùng float32 cho CPU
22
  trust_remote_code=True,
23
  attn_implementation="eager"
24
  )
25
  print("Tải xong mô hình OCR.")
26
 
27
- # 2. Tải hình LLM (Không nén)
28
- print("Đang tải mô hình LLM (Meta Llama 3 8B)...")
29
- llm_model_id = "meta-llama/Meta-Llama-3-8B-Instruct" # Dùng phiên bản gốc
30
- llm_pipeline = pipeline(
31
- "text-generation",
32
- model=llm_model_id,
33
- model_kwargs={"torch_dtype": torch.float32}, # Dùng float32 cho CPU
34
- device=0 if device == "cuda" else -1, # -1 để pipeline dùng CPU
 
 
 
 
35
  )
36
  print("Tải xong mô hình LLM.")
37
 
38
 
39
- # --- CÁC HÀM XỬ LÝ (GIỮ NGUYÊN) ---
40
 
41
  def run_ocr(image: Image.Image) -> str:
42
  if image is None: return "Lỗi: Vui lòng cung cấp hình ảnh."
43
  prompt = "<OCR>"
44
  inputs = ocr_processor(text=prompt, images=image, return_tensors="pt").to(device)
45
- generated_ids = oocr_model.generate(
46
  input_ids=inputs["input_ids"],
47
  pixel_values=inputs["pixel_values"],
48
  max_new_tokens=2048,
@@ -53,17 +58,41 @@ def run_ocr(image: Image.Image) -> str:
53
  return parsed_text['<OCR>']
54
 
55
  def extract_order_from_text(text: str) -> dict:
56
- prompt = f"""You are an expert assistant for extracting order information from unstructured text. Based on the following text, extract the information and return it as a valid JSON object. The JSON object must contain: "ten_khach_hang": The customer's name (null if not found). "danh_sach_hang": An array of items. Each item must have: "ten_hang", "so_luong" (as a number), "don_vi", "ma_hang" (null if not found), and "ghi_chu" (null if not found). Output only the JSON object, with no additional text or explanation. --- Text Content --- {text} --- End Text Content ---"""
57
- messages = [{"role": "system", "content": "You are an assistant that only outputs valid JSON."}, {"role": "user", "content": prompt},]
58
- terminators = [llm_pipeline.tokenizer.eos_token_id, llm_pipeline.tokenizer.convert_tokens_to_ids("<|eot_id|>")]
59
- outputs = llm_pipeline(messages, max_new_tokens=1024, eos_token_id=terminators, do_sample=False)
60
- response_text = outputs[0]["generated_text"][-1]['content']
61
- try: return json.loads(response_text)
62
- except json.JSONDecodeError: return {"error": "AI trả về định dạng không hợp lệ", "raw_response": response_text}
 
 
 
 
 
 
 
 
 
 
 
 
 
63
 
64
  def create_excel_file(order_data: dict):
65
  if not order_data or "danh_sach_hang" not in order_data or not order_data["danh_sach_hang"]: return None
66
- flat_data = [{'Khách hàng': order_data.get('ten_khach_hang', 'N/A'), **item} for item in order_data['danh_sach_hang']]
 
 
 
 
 
 
 
 
 
 
 
67
  df = pd.DataFrame(flat_data)
68
  output = io.BytesIO()
69
  with pd.ExcelWriter(output, engine='openpyxl') as writer: df.to_excel(writer, index=False, sheet_name='DonHang')
@@ -74,36 +103,25 @@ def create_excel_file(order_data: dict):
74
  def process_image_and_extract(image):
75
  try:
76
  if image is None: return "Vui lòng dán ảnh vào.", None, None
77
-
78
- print("Bắt đầu OCR...")
79
  extracted_text = run_ocr(image)
80
  if not extracted_text.strip(): return "Không đọc được chữ từ hình ảnh.", None, None
81
- print(f"Văn bản OCR: {extracted_text}")
82
-
83
- print("Bắt đầu trích xuất LLM...")
84
  order_data = extract_order_from_text(extracted_text)
85
  if "error" in order_data: return extracted_text, f"Lỗi từ AI: {order_data['error']}\nPhản hồi gốc: {order_data['raw_response']}", None
86
- print(f"Dữ liệu trích xuất: {order_data}")
87
-
88
  excel_info = create_excel_file(order_data)
89
  df_display = pd.DataFrame(order_data.get('danh_sach_hang', []))
90
-
91
  if excel_info:
92
  filename, filebytes = excel_info
93
  with open(filename, "wb") as f: f.write(filebytes)
94
  return extracted_text, df_display, filename
95
  else: return extracted_text, df_display, None
96
  except Exception as e:
97
- # Bắt lỗi và hiển thị trong giao diện để dễ gỡ lỗi
98
  import traceback
99
  error_str = str(e)
100
  traceback_str = traceback.format_exc()
101
  print(traceback_str)
102
  return f"Lỗi nghiêm trọng: {error_str}", None, None
103
 
104
-
105
  # --- XÂY DỰNG GIAO DIỆN GRADIO (GIỮ NGUYÊN) ---
106
-
107
  with gr.Blocks(theme=gr.themes.Soft()) as app:
108
  gr.Markdown("# Ứng dụng Trích xuất Đơn hàng từ Ảnh chụp màn hình")
109
  gr.Markdown("Chụp màn hình email/tin nhắn đặt hàng, sau đó dán (Ctrl+V) vào ô bên dưới và nhấn 'Xử lý'.")
@@ -118,6 +136,5 @@ with gr.Blocks(theme=gr.themes.Soft()) as app:
118
  gr.Markdown("### Văn bản đọc được từ ảnh (OCR)")
119
  output_text = gr.Textbox(label="Text from Image", lines=10, interactive=False)
120
  process_btn.click(fn=process_image_and_extract, inputs=image_input, outputs=[output_text, output_table, output_excel])
121
-
122
 
123
  app.launch(debug=True)
 
5
  import io
6
  from PIL import Image
7
  import torch
8
+ from transformers import AutoProcessor, AutoModelForCausalLM
9
+ from ctransformers import AutoModelForCausalLM as CAutoModelForCausalLM # Thư viện mới
10
 
11
+ # --- CẤU HÌNH VÀ TẢI MÔ HÌNH (PHIÊN BẢN ỔN ĐỊNH NHẤT CHO CPU) ---
12
+ device = "cpu"
13
  print(f"Đang sử dụng thiết bị: {device}")
14
 
15
+ # 1. Tải mô hình OCR (Giữ nguyên, đã ổn định)
16
  print("Đang tải mô hình OCR (Florence-2-base)...")
17
  ocr_model_id = "microsoft/Florence-2-base"
18
  ocr_processor = AutoProcessor.from_pretrained(ocr_model_id, trust_remote_code=True)
19
  ocr_model = AutoModelForCausalLM.from_pretrained(
20
  ocr_model_id,
21
+ device_map=device,
22
+ torch_dtype=torch.float32,
23
  trust_remote_code=True,
24
  attn_implementation="eager"
25
  )
26
  print("Tải xong mô hình OCR.")
27
 
28
+ # 2. THAY ĐỔI LỚN: DÙNG CTRANSFORMERS ĐỂ TẢI GGUF
29
+ print("Đang tải mô hình LLM (Llama-3-8B GGUF for CPU)...")
30
+ # Sử dụng phiên bản GGUF được tối ưu cho CPU
31
+ llm_model_id = "QuantFactory/Meta-Llama-3-8B-Instruct-GGUF"
32
+ llm_model_file = "Meta-Llama-3-8B-Instruct.Q4_K_M.gguf" # File quantization 4-bit, cân bằng tốt
33
+ llm = CAutoModelForCausalLM.from_pretrained(
34
+ llm_model_id,
35
+ model_file=llm_model_file,
36
+ model_type="llama",
37
+ # Cấu hình để chạy trên CPU
38
+ gpu_layers=0, # Không dùng GPU
39
+ context_length=4096
40
  )
41
  print("Tải xong mô hình LLM.")
42
 
43
 
44
+ # --- CÁC HÀM XỬ LÝ ---
45
 
46
  def run_ocr(image: Image.Image) -> str:
47
  if image is None: return "Lỗi: Vui lòng cung cấp hình ảnh."
48
  prompt = "<OCR>"
49
  inputs = ocr_processor(text=prompt, images=image, return_tensors="pt").to(device)
50
+ generated_ids = ocr_model.generate(
51
  input_ids=inputs["input_ids"],
52
  pixel_values=inputs["pixel_values"],
53
  max_new_tokens=2048,
 
58
  return parsed_text['<OCR>']
59
 
60
  def extract_order_from_text(text: str) -> dict:
61
+ # Cập nhật prompt cho Llama 3 GGUF
62
+ prompt = f"""<|begin_of_text|><|start_header_id|>system<|end_header_id|>
63
+ You are an expert assistant that only outputs valid JSON. Extract order information from the text. The JSON object must contain "ten_khach_hang" (string, null if not found) and "danh_sach_hang" (an array of items). Each item must have "ten_hang" (string), "so_luong" (number), "don_vi" (string), "ma_hang" (string, null if not found), and "ghi_chu" (string, null if not found).<|eot_id|><|start_header_id|>user<|end_header_id|>
64
+ Text Content:
65
+ {text}<|eot_id|><|start_header_id|>assistant<|end_header_id|>
66
+ """
67
+
68
+ response_text = llm(prompt, max_new_tokens=1024, temperature=0.1, stop=["<|eot_id|>"])
69
+
70
+ try:
71
+ # Làm sạch output thô từ GGUF model
72
+ json_str = response_text.strip()
73
+ # Tìm vị trí bắt đầu và kết thúc của JSON
74
+ start = json_str.find('{')
75
+ end = json_str.rfind('}') + 1
76
+ if start != -1 and end != 0:
77
+ json_str = json_str[start:end]
78
+ return json.loads(json_str)
79
+ except json.JSONDecodeError:
80
+ return {"error": "AI trả về định dạng không hợp lệ", "raw_response": response_text}
81
 
82
  def create_excel_file(order_data: dict):
83
  if not order_data or "danh_sach_hang" not in order_data or not order_data["danh_sach_hang"]: return None
84
+ # Đổi tên key cho phù hợp với DataFrame
85
+ flat_data = []
86
+ customer = order_data.get('ten_khach_hang', 'N/A')
87
+ for item in order_data['danh_sach_hang']:
88
+ flat_data.append({
89
+ 'Khách hàng': customer,
90
+ 'Mã hàng': item.get('ma_hang'),
91
+ 'Tên hàng': item.get('ten_hang'),
92
+ 'Số lượng': item.get('so_luong'),
93
+ 'Đơn vị': item.get('don_vi'),
94
+ 'Ghi chú': item.get('ghi_chu')
95
+ })
96
  df = pd.DataFrame(flat_data)
97
  output = io.BytesIO()
98
  with pd.ExcelWriter(output, engine='openpyxl') as writer: df.to_excel(writer, index=False, sheet_name='DonHang')
 
103
  def process_image_and_extract(image):
104
  try:
105
  if image is None: return "Vui lòng dán ảnh vào.", None, None
 
 
106
  extracted_text = run_ocr(image)
107
  if not extracted_text.strip(): return "Không đọc được chữ từ hình ảnh.", None, None
 
 
 
108
  order_data = extract_order_from_text(extracted_text)
109
  if "error" in order_data: return extracted_text, f"Lỗi từ AI: {order_data['error']}\nPhản hồi gốc: {order_data['raw_response']}", None
 
 
110
  excel_info = create_excel_file(order_data)
111
  df_display = pd.DataFrame(order_data.get('danh_sach_hang', []))
 
112
  if excel_info:
113
  filename, filebytes = excel_info
114
  with open(filename, "wb") as f: f.write(filebytes)
115
  return extracted_text, df_display, filename
116
  else: return extracted_text, df_display, None
117
  except Exception as e:
 
118
  import traceback
119
  error_str = str(e)
120
  traceback_str = traceback.format_exc()
121
  print(traceback_str)
122
  return f"Lỗi nghiêm trọng: {error_str}", None, None
123
 
 
124
  # --- XÂY DỰNG GIAO DIỆN GRADIO (GIỮ NGUYÊN) ---
 
125
  with gr.Blocks(theme=gr.themes.Soft()) as app:
126
  gr.Markdown("# Ứng dụng Trích xuất Đơn hàng từ Ảnh chụp màn hình")
127
  gr.Markdown("Chụp màn hình email/tin nhắn đặt hàng, sau đó dán (Ctrl+V) vào ô bên dưới và nhấn 'Xử lý'.")
 
136
  gr.Markdown("### Văn bản đọc được từ ảnh (OCR)")
137
  output_text = gr.Textbox(label="Text from Image", lines=10, interactive=False)
138
  process_btn.click(fn=process_image_and_extract, inputs=image_input, outputs=[output_text, output_table, output_excel])
 
139
 
140
  app.launch(debug=True)