ivanhoang commited on
Commit
420a76d
·
verified ·
1 Parent(s): ce6b91b

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +111 -0
app.py ADDED
@@ -0,0 +1,111 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import pandas as pd
3
+ from datetime import datetime
4
+ import json
5
+ import io
6
+ from PIL import Image
7
+ import torch
8
+ from transformers import AutoProcessor, AutoModelForCausalLM, BitsAndBytesConfig
9
+
10
+ # --- CẤU HÌNH VÀ TẢI MÔ HÌNH ---
11
+ # Sử dụng GPU nếu có, nếu không thì dùng CPU
12
+ device = "cuda" if torch.cuda.is_available() else "cpu"
13
+ print(f"Đang sử dụng thiết bị: {device}")
14
+
15
+ # 1. Tải mô hình OCR (Optical Character Recognition)
16
+ print("Đang tải mô hình OCR (Florence-2)...")
17
+ ocr_model_id = "microsoft/Florence-2-large"
18
+ ocr_processor = AutoProcessor.from_pretrained(ocr_model_id, trust_remote_code=True)
19
+ quantization_config = BitsAndBytesConfig(load_in_4bit=True)
20
+ ocr_model = AutoModelForCausalLM.from_pretrained(
21
+ ocr_model_id,
22
+ device_map=device,
23
+ torch_dtype=torch.bfloat16,
24
+ quantization_config=quantization_config,
25
+ trust_remote_code=True
26
+ )
27
+ print("Tải xong mô hình OCR.")
28
+
29
+ # 2. Tải mô hình LLM (Language Model)
30
+ print("Đang tải mô hình LLM (Meta Llama 3 8B)...")
31
+ llm_model_id = "meta-llama/Meta-Llama-3-8B-Instruct"
32
+ from transformers import pipeline
33
+ llm_pipeline = pipeline(
34
+ "text-generation",
35
+ model=llm_model_id,
36
+ model_kwargs={"torch_dtype": torch.bfloat16},
37
+ device_map="auto",
38
+ )
39
+ print("Tải xong mô hình LLM.")
40
+
41
+ # --- CÁC HÀM XỬ LÝ ---
42
+
43
+ def run_ocr(image: Image.Image) -> str:
44
+ """Hàm chạy OCR để đọc chữ từ ảnh"""
45
+ if image is None: return "Lỗi: Vui lòng cung cấp hình ảnh."
46
+ prompt = "<OCR>"
47
+ inputs = ocr_processor(text=prompt, images=image, return_tensors="pt").to(device)
48
+ generated_ids = ocr_model.generate(
49
+ input_ids=inputs["input_ids"],
50
+ pixel_values=inputs["pixel_values"],
51
+ max_new_tokens=2048,
52
+ num_beams=3
53
+ )
54
+ generated_text = ocr_processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
55
+ parsed_text = ocr_processor.post_process_generation(generated_text, task="<OCR>", image_size=(image.width, image.height))
56
+ return parsed_text['<OCR>']
57
+
58
+ def extract_order_from_text(text: str) -> dict:
59
+ """Hàm chạy LLM để trích xuất thông tin từ văn bản"""
60
+ prompt = f"""You are an expert assistant for extracting order information from unstructured text. Based on the following text, extract the information and return it as a valid JSON object. The JSON object must contain: "ten_khach_hang": The customer's name (null if not found). "danh_sach_hang": An array of items. Each item must have: "ten_hang", "so_luong" (as a number), "don_vi", "ma_hang" (null if not found), and "ghi_chu" (null if not found). Output only the JSON object, with no additional text or explanation. --- Text Content --- {text} --- End Text Content ---"""
61
+ messages = [{"role": "system", "content": "You are an assistant that only outputs valid JSON."}, {"role": "user", "content": prompt},]
62
+ terminators = [llm_pipeline.tokenizer.eos_token_id, llm_pipeline.tokenizer.convert_tokens_to_ids("<|eot_id|>")]
63
+ outputs = llm_pipeline(messages, max_new_tokens=1024, eos_token_id=terminators, do_sample=False)
64
+ response_text = outputs[0]["generated_text"][-1]['content']
65
+ try: return json.loads(response_text)
66
+ except json.JSONDecodeError: return {"error": "AI trả về định dạng không hợp lệ", "raw_response": response_text}
67
+
68
+ def create_excel_file(order_data: dict):
69
+ """Hàm tạo file Excel từ dữ liệu đã trích xuất"""
70
+ if not order_data or "danh_sach_hang" not in order_data or not order_data["danh_sach_hang"]: return None
71
+ flat_data = [{'Khách hàng': order_data.get('ten_khach_hang', 'N/A'), **item} for item in order_data['danh_sach_hang']]
72
+ df = pd.DataFrame(flat_data)
73
+ output = io.BytesIO()
74
+ with pd.ExcelWriter(output, engine='openpyxl') as writer: df.to_excel(writer, index=False, sheet_name='DonHang')
75
+ timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
76
+ filename = f"don_hang_{timestamp}.xlsx"
77
+ return (filename, output.getvalue())
78
+
79
+ def process_image_and_extract(image):
80
+ """Hàm tổng hợp, gọi OCR rồi đến LLM"""
81
+ if image is None: return "Vui lòng dán ảnh vào.", None, None
82
+ extracted_text = run_ocr(image)
83
+ if not extracted_text.strip(): return "Không đọc được chữ từ hình ảnh.", None, None
84
+ order_data = extract_order_from_text(extracted_text)
85
+ if "error" in order_data: return extracted_text, f"Lỗi từ AI: {order_data['error']}\nPhản hồi gốc: {order_data['raw_response']}", None
86
+ excel_info = create_excel_file(order_data)
87
+ df_display = pd.DataFrame(order_data.get('danh_sach_hang', []))
88
+ if excel_info:
89
+ filename, filebytes = excel_info
90
+ with open(filename, "wb") as f: f.write(filebytes)
91
+ return extracted_text, df_display, filename
92
+ else: return extracted_text, df_display, None
93
+
94
+ # --- XÂY DỰNG GIAO DIỆN GRADIO ---
95
+
96
+ with gr.Blocks(theme=gr.themes.Soft()) as app:
97
+ gr.Markdown("# Ứng dụng Trích xuất Đơn hàng từ ���nh chụp màn hình")
98
+ gr.Markdown("Chụp màn hình email/tin nhắn đặt hàng, sau đó dán (Ctrl+V) vào ô bên dưới và nhấn 'Xử lý'.")
99
+ with gr.Row():
100
+ with gr.Column(scale=1):
101
+ image_input = gr.Image(label="Dán ảnh chụp màn hình vào đây", type="pil", sources=["clipboard", "upload"])
102
+ process_btn = gr.Button("Xử lý", variant="primary")
103
+ with gr.Column(scale=2):
104
+ gr.Markdown("### Kết quả trích xuất")
105
+ output_table = gr.DataFrame(label="Chi tiết đơn hàng")
106
+ output_excel = gr.File(label="Tải file Excel")
107
+ gr.Markdown("### Văn bản đọc được từ ảnh (OCR)")
108
+ output_text = gr.Textbox(label="Text from Image", lines=10, interactive=False)
109
+ process_btn.click(fn=process_image_and_extract, inputs=image_input, outputs=[output_text, output_table, output_excel])
110
+
111
+ app.launch(debug=True)