File size: 6,291 Bytes
420a76d
 
 
 
 
 
 
e68ec86
420a76d
e68ec86
 
420a76d
 
e68ec86
1ad2e0c
62477e6
420a76d
 
 
e68ec86
 
62477e6
e68ec86
420a76d
 
 
e68ec86
 
 
420a76d
 
 
e68ec86
 
420a76d
 
 
1ad2e0c
 
420a76d
 
 
 
 
e68ec86
420a76d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e68ec86
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
420a76d
1ad2e0c
420a76d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
import gradio as gr
import pandas as pd
from datetime import datetime
import json
import io
from PIL import Image
import torch
from transformers import AutoProcessor, AutoModelForCausalLM, pipeline

# --- CẤU HÌNH VÀ TẢI MÔ HÌNH (PHIÊN BẢN ỔN ĐỊNH CHO CPU) ---
device = "cpu" # Ép chạy trên CPU để đảm bảo
print(f"Đang sử dụng thiết bị: {device}")

# 1. Tải mô hình OCR (Không nén)
print("Đang tải mô hình OCR (Florence-2-base)...")
ocr_model_id = "microsoft/Florence-2-base"
ocr_processor = AutoProcessor.from_pretrained(ocr_model_id, trust_remote_code=True)
ocr_model = AutoModelForCausalLM.from_pretrained(
    ocr_model_id, 
    device_map=device, # Chạy trên CPU
    torch_dtype=torch.float32, # Dùng float32 cho CPU
    trust_remote_code=True,
    attn_implementation="eager" 
)
print("Tải xong mô hình OCR.")

# 2. Tải mô hình LLM (Không nén)
print("Đang tải mô hình LLM (Meta Llama 3 8B)...")
llm_model_id = "meta-llama/Meta-Llama-3-8B-Instruct" # Dùng phiên bản gốc
llm_pipeline = pipeline(
    "text-generation",
    model=llm_model_id,
    model_kwargs={"torch_dtype": torch.float32}, # Dùng float32 cho CPU
    device=0 if device == "cuda" else -1, # -1 để pipeline dùng CPU
)
print("Tải xong mô hình LLM.")


# --- CÁC HÀM XỬ LÝ (GIỮ NGUYÊN) ---

def run_ocr(image: Image.Image) -> str:
    if image is None: return "Lỗi: Vui lòng cung cấp hình ảnh."
    prompt = "<OCR>"
    inputs = ocr_processor(text=prompt, images=image, return_tensors="pt").to(device)
    generated_ids = oocr_model.generate(
        input_ids=inputs["input_ids"],
        pixel_values=inputs["pixel_values"],
        max_new_tokens=2048,
        num_beams=3
    )
    generated_text = ocr_processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
    parsed_text = ocr_processor.post_process_generation(generated_text, task="<OCR>", image_size=(image.width, image.height))
    return parsed_text['<OCR>']

def extract_order_from_text(text: str) -> dict:
    prompt = f"""You are an expert assistant for extracting order information from unstructured text. Based on the following text, extract the information and return it as a valid JSON object. The JSON object must contain: "ten_khach_hang": The customer's name (null if not found). "danh_sach_hang": An array of items. Each item must have: "ten_hang", "so_luong" (as a number), "don_vi", "ma_hang" (null if not found), and "ghi_chu" (null if not found). Output only the JSON object, with no additional text or explanation. --- Text Content --- {text} --- End Text Content ---"""
    messages = [{"role": "system", "content": "You are an assistant that only outputs valid JSON."}, {"role": "user", "content": prompt},]
    terminators = [llm_pipeline.tokenizer.eos_token_id, llm_pipeline.tokenizer.convert_tokens_to_ids("<|eot_id|>")]
    outputs = llm_pipeline(messages, max_new_tokens=1024, eos_token_id=terminators, do_sample=False)
    response_text = outputs[0]["generated_text"][-1]['content']
    try: return json.loads(response_text)
    except json.JSONDecodeError: return {"error": "AI trả về định dạng không hợp lệ", "raw_response": response_text}

def create_excel_file(order_data: dict):
    if not order_data or "danh_sach_hang" not in order_data or not order_data["danh_sach_hang"]: return None
    flat_data = [{'Khách hàng': order_data.get('ten_khach_hang', 'N/A'), **item} for item in order_data['danh_sach_hang']]
    df = pd.DataFrame(flat_data)
    output = io.BytesIO()
    with pd.ExcelWriter(output, engine='openpyxl') as writer: df.to_excel(writer, index=False, sheet_name='DonHang')
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    filename = f"don_hang_{timestamp}.xlsx"
    return (filename, output.getvalue())

def process_image_and_extract(image):
    try:
        if image is None: return "Vui lòng dán ảnh vào.", None, None
        
        print("Bắt đầu OCR...")
        extracted_text = run_ocr(image)
        if not extracted_text.strip(): return "Không đọc được chữ từ hình ảnh.", None, None
        print(f"Văn bản OCR: {extracted_text}")

        print("Bắt đầu trích xuất LLM...")
        order_data = extract_order_from_text(extracted_text)
        if "error" in order_data: return extracted_text, f"Lỗi từ AI: {order_data['error']}\nPhản hồi gốc: {order_data['raw_response']}", None
        print(f"Dữ liệu trích xuất: {order_data}")

        excel_info = create_excel_file(order_data)
        df_display = pd.DataFrame(order_data.get('danh_sach_hang', []))
        
        if excel_info:
            filename, filebytes = excel_info
            with open(filename, "wb") as f: f.write(filebytes)
            return extracted_text, df_display, filename
        else: return extracted_text, df_display, None
    except Exception as e:
        # Bắt lỗi và hiển thị trong giao diện để dễ gỡ lỗi
        import traceback
        error_str = str(e)
        traceback_str = traceback.format_exc()
        print(traceback_str)
        return f"Lỗi nghiêm trọng: {error_str}", None, None


# --- XÂY DỰNG GIAO DIỆN GRADIO (GIỮ NGUYÊN) ---

with gr.Blocks(theme=gr.themes.Soft()) as app:
    gr.Markdown("# Ứng dụng Trích xuất Đơn hàng từ Ảnh chụp màn hình")
    gr.Markdown("Chụp màn hình email/tin nhắn đặt hàng, sau đó dán (Ctrl+V) vào ô bên dưới và nhấn 'Xử lý'.")
    with gr.Row():
        with gr.Column(scale=1):
            image_input = gr.Image(label="Dán ảnh chụp màn hình vào đây", type="pil", sources=["clipboard", "upload"])
            process_btn = gr.Button("Xử lý", variant="primary")
        with gr.Column(scale=2):
            gr.Markdown("### Kết quả trích xuất")
            output_table = gr.DataFrame(label="Chi tiết đơn hàng")
            output_excel = gr.File(label="Tải file Excel")
            gr.Markdown("### Văn bản đọc được từ ảnh (OCR)")
            output_text = gr.Textbox(label="Text from Image", lines=10, interactive=False)
    process_btn.click(fn=process_image_and_extract, inputs=image_input, outputs=[output_text, output_table, output_excel])

app.launch(debug=True)