Spaces:
Build error
Build error
| import gradio as gr | |
| import pandas as pd | |
| from datetime import datetime | |
| import json | |
| import io | |
| from PIL import Image | |
| import torch | |
| from transformers import AutoProcessor, AutoModelForCausalLM, pipeline | |
| # --- CẤU HÌNH VÀ TẢI MÔ HÌNH (PHIÊN BẢN ỔN ĐỊNH CHO CPU) --- | |
| device = "cpu" # Ép chạy trên CPU để đảm bảo | |
| print(f"Đang sử dụng thiết bị: {device}") | |
| # 1. Tải mô hình OCR (Không nén) | |
| print("Đang tải mô hình OCR (Florence-2-base)...") | |
| ocr_model_id = "microsoft/Florence-2-base" | |
| ocr_processor = AutoProcessor.from_pretrained(ocr_model_id, trust_remote_code=True) | |
| ocr_model = AutoModelForCausalLM.from_pretrained( | |
| ocr_model_id, | |
| device_map=device, # Chạy trên CPU | |
| torch_dtype=torch.float32, # Dùng float32 cho CPU | |
| trust_remote_code=True, | |
| attn_implementation="eager" | |
| ) | |
| print("Tải xong mô hình OCR.") | |
| # 2. Tải mô hình LLM (Không nén) | |
| print("Đang tải mô hình LLM (Meta Llama 3 8B)...") | |
| llm_model_id = "meta-llama/Meta-Llama-3-8B-Instruct" # Dùng phiên bản gốc | |
| llm_pipeline = pipeline( | |
| "text-generation", | |
| model=llm_model_id, | |
| model_kwargs={"torch_dtype": torch.float32}, # Dùng float32 cho CPU | |
| device=0 if device == "cuda" else -1, # -1 để pipeline dùng CPU | |
| ) | |
| print("Tải xong mô hình LLM.") | |
| # --- CÁC HÀM XỬ LÝ (GIỮ NGUYÊN) --- | |
| def run_ocr(image: Image.Image) -> str: | |
| if image is None: return "Lỗi: Vui lòng cung cấp hình ảnh." | |
| prompt = "<OCR>" | |
| inputs = ocr_processor(text=prompt, images=image, return_tensors="pt").to(device) | |
| generated_ids = oocr_model.generate( | |
| input_ids=inputs["input_ids"], | |
| pixel_values=inputs["pixel_values"], | |
| max_new_tokens=2048, | |
| num_beams=3 | |
| ) | |
| generated_text = ocr_processor.batch_decode(generated_ids, skip_special_tokens=True)[0] | |
| parsed_text = ocr_processor.post_process_generation(generated_text, task="<OCR>", image_size=(image.width, image.height)) | |
| return parsed_text['<OCR>'] | |
| def extract_order_from_text(text: str) -> dict: | |
| prompt = f"""You are an expert assistant for extracting order information from unstructured text. Based on the following text, extract the information and return it as a valid JSON object. The JSON object must contain: "ten_khach_hang": The customer's name (null if not found). "danh_sach_hang": An array of items. Each item must have: "ten_hang", "so_luong" (as a number), "don_vi", "ma_hang" (null if not found), and "ghi_chu" (null if not found). Output only the JSON object, with no additional text or explanation. --- Text Content --- {text} --- End Text Content ---""" | |
| messages = [{"role": "system", "content": "You are an assistant that only outputs valid JSON."}, {"role": "user", "content": prompt},] | |
| terminators = [llm_pipeline.tokenizer.eos_token_id, llm_pipeline.tokenizer.convert_tokens_to_ids("<|eot_id|>")] | |
| outputs = llm_pipeline(messages, max_new_tokens=1024, eos_token_id=terminators, do_sample=False) | |
| response_text = outputs[0]["generated_text"][-1]['content'] | |
| try: return json.loads(response_text) | |
| except json.JSONDecodeError: return {"error": "AI trả về định dạng không hợp lệ", "raw_response": response_text} | |
| def create_excel_file(order_data: dict): | |
| if not order_data or "danh_sach_hang" not in order_data or not order_data["danh_sach_hang"]: return None | |
| flat_data = [{'Khách hàng': order_data.get('ten_khach_hang', 'N/A'), **item} for item in order_data['danh_sach_hang']] | |
| df = pd.DataFrame(flat_data) | |
| output = io.BytesIO() | |
| with pd.ExcelWriter(output, engine='openpyxl') as writer: df.to_excel(writer, index=False, sheet_name='DonHang') | |
| timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") | |
| filename = f"don_hang_{timestamp}.xlsx" | |
| return (filename, output.getvalue()) | |
| def process_image_and_extract(image): | |
| try: | |
| if image is None: return "Vui lòng dán ảnh vào.", None, None | |
| print("Bắt đầu OCR...") | |
| extracted_text = run_ocr(image) | |
| if not extracted_text.strip(): return "Không đọc được chữ từ hình ảnh.", None, None | |
| print(f"Văn bản OCR: {extracted_text}") | |
| print("Bắt đầu trích xuất LLM...") | |
| order_data = extract_order_from_text(extracted_text) | |
| if "error" in order_data: return extracted_text, f"Lỗi từ AI: {order_data['error']}\nPhản hồi gốc: {order_data['raw_response']}", None | |
| print(f"Dữ liệu trích xuất: {order_data}") | |
| excel_info = create_excel_file(order_data) | |
| df_display = pd.DataFrame(order_data.get('danh_sach_hang', [])) | |
| if excel_info: | |
| filename, filebytes = excel_info | |
| with open(filename, "wb") as f: f.write(filebytes) | |
| return extracted_text, df_display, filename | |
| else: return extracted_text, df_display, None | |
| except Exception as e: | |
| # Bắt lỗi và hiển thị trong giao diện để dễ gỡ lỗi | |
| import traceback | |
| error_str = str(e) | |
| traceback_str = traceback.format_exc() | |
| print(traceback_str) | |
| return f"Lỗi nghiêm trọng: {error_str}", None, None | |
| # --- XÂY DỰNG GIAO DIỆN GRADIO (GIỮ NGUYÊN) --- | |
| with gr.Blocks(theme=gr.themes.Soft()) as app: | |
| gr.Markdown("# Ứng dụng Trích xuất Đơn hàng từ Ảnh chụp màn hình") | |
| gr.Markdown("Chụp màn hình email/tin nhắn đặt hàng, sau đó dán (Ctrl+V) vào ô bên dưới và nhấn 'Xử lý'.") | |
| with gr.Row(): | |
| with gr.Column(scale=1): | |
| image_input = gr.Image(label="Dán ảnh chụp màn hình vào đây", type="pil", sources=["clipboard", "upload"]) | |
| process_btn = gr.Button("Xử lý", variant="primary") | |
| with gr.Column(scale=2): | |
| gr.Markdown("### Kết quả trích xuất") | |
| output_table = gr.DataFrame(label="Chi tiết đơn hàng") | |
| output_excel = gr.File(label="Tải file Excel") | |
| gr.Markdown("### Văn bản đọc được từ ảnh (OCR)") | |
| output_text = gr.Textbox(label="Text from Image", lines=10, interactive=False) | |
| process_btn.click(fn=process_image_and_extract, inputs=image_input, outputs=[output_text, output_table, output_excel]) | |
| app.launch(debug=True) |