ivanhoang's picture
Update app.py
e68ec86 verified
raw
history blame
6.29 kB
import gradio as gr
import pandas as pd
from datetime import datetime
import json
import io
from PIL import Image
import torch
from transformers import AutoProcessor, AutoModelForCausalLM, pipeline
# --- CẤU HÌNH VÀ TẢI MÔ HÌNH (PHIÊN BẢN ỔN ĐỊNH CHO CPU) ---
device = "cpu" # Ép chạy trên CPU để đảm bảo
print(f"Đang sử dụng thiết bị: {device}")
# 1. Tải mô hình OCR (Không nén)
print("Đang tải mô hình OCR (Florence-2-base)...")
ocr_model_id = "microsoft/Florence-2-base"
ocr_processor = AutoProcessor.from_pretrained(ocr_model_id, trust_remote_code=True)
ocr_model = AutoModelForCausalLM.from_pretrained(
ocr_model_id,
device_map=device, # Chạy trên CPU
torch_dtype=torch.float32, # Dùng float32 cho CPU
trust_remote_code=True,
attn_implementation="eager"
)
print("Tải xong mô hình OCR.")
# 2. Tải mô hình LLM (Không nén)
print("Đang tải mô hình LLM (Meta Llama 3 8B)...")
llm_model_id = "meta-llama/Meta-Llama-3-8B-Instruct" # Dùng phiên bản gốc
llm_pipeline = pipeline(
"text-generation",
model=llm_model_id,
model_kwargs={"torch_dtype": torch.float32}, # Dùng float32 cho CPU
device=0 if device == "cuda" else -1, # -1 để pipeline dùng CPU
)
print("Tải xong mô hình LLM.")
# --- CÁC HÀM XỬ LÝ (GIỮ NGUYÊN) ---
def run_ocr(image: Image.Image) -> str:
if image is None: return "Lỗi: Vui lòng cung cấp hình ảnh."
prompt = "<OCR>"
inputs = ocr_processor(text=prompt, images=image, return_tensors="pt").to(device)
generated_ids = oocr_model.generate(
input_ids=inputs["input_ids"],
pixel_values=inputs["pixel_values"],
max_new_tokens=2048,
num_beams=3
)
generated_text = ocr_processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
parsed_text = ocr_processor.post_process_generation(generated_text, task="<OCR>", image_size=(image.width, image.height))
return parsed_text['<OCR>']
def extract_order_from_text(text: str) -> dict:
prompt = f"""You are an expert assistant for extracting order information from unstructured text. Based on the following text, extract the information and return it as a valid JSON object. The JSON object must contain: "ten_khach_hang": The customer's name (null if not found). "danh_sach_hang": An array of items. Each item must have: "ten_hang", "so_luong" (as a number), "don_vi", "ma_hang" (null if not found), and "ghi_chu" (null if not found). Output only the JSON object, with no additional text or explanation. --- Text Content --- {text} --- End Text Content ---"""
messages = [{"role": "system", "content": "You are an assistant that only outputs valid JSON."}, {"role": "user", "content": prompt},]
terminators = [llm_pipeline.tokenizer.eos_token_id, llm_pipeline.tokenizer.convert_tokens_to_ids("<|eot_id|>")]
outputs = llm_pipeline(messages, max_new_tokens=1024, eos_token_id=terminators, do_sample=False)
response_text = outputs[0]["generated_text"][-1]['content']
try: return json.loads(response_text)
except json.JSONDecodeError: return {"error": "AI trả về định dạng không hợp lệ", "raw_response": response_text}
def create_excel_file(order_data: dict):
if not order_data or "danh_sach_hang" not in order_data or not order_data["danh_sach_hang"]: return None
flat_data = [{'Khách hàng': order_data.get('ten_khach_hang', 'N/A'), **item} for item in order_data['danh_sach_hang']]
df = pd.DataFrame(flat_data)
output = io.BytesIO()
with pd.ExcelWriter(output, engine='openpyxl') as writer: df.to_excel(writer, index=False, sheet_name='DonHang')
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
filename = f"don_hang_{timestamp}.xlsx"
return (filename, output.getvalue())
def process_image_and_extract(image):
try:
if image is None: return "Vui lòng dán ảnh vào.", None, None
print("Bắt đầu OCR...")
extracted_text = run_ocr(image)
if not extracted_text.strip(): return "Không đọc được chữ từ hình ảnh.", None, None
print(f"Văn bản OCR: {extracted_text}")
print("Bắt đầu trích xuất LLM...")
order_data = extract_order_from_text(extracted_text)
if "error" in order_data: return extracted_text, f"Lỗi từ AI: {order_data['error']}\nPhản hồi gốc: {order_data['raw_response']}", None
print(f"Dữ liệu trích xuất: {order_data}")
excel_info = create_excel_file(order_data)
df_display = pd.DataFrame(order_data.get('danh_sach_hang', []))
if excel_info:
filename, filebytes = excel_info
with open(filename, "wb") as f: f.write(filebytes)
return extracted_text, df_display, filename
else: return extracted_text, df_display, None
except Exception as e:
# Bắt lỗi và hiển thị trong giao diện để dễ gỡ lỗi
import traceback
error_str = str(e)
traceback_str = traceback.format_exc()
print(traceback_str)
return f"Lỗi nghiêm trọng: {error_str}", None, None
# --- XÂY DỰNG GIAO DIỆN GRADIO (GIỮ NGUYÊN) ---
with gr.Blocks(theme=gr.themes.Soft()) as app:
gr.Markdown("# Ứng dụng Trích xuất Đơn hàng từ Ảnh chụp màn hình")
gr.Markdown("Chụp màn hình email/tin nhắn đặt hàng, sau đó dán (Ctrl+V) vào ô bên dưới và nhấn 'Xử lý'.")
with gr.Row():
with gr.Column(scale=1):
image_input = gr.Image(label="Dán ảnh chụp màn hình vào đây", type="pil", sources=["clipboard", "upload"])
process_btn = gr.Button("Xử lý", variant="primary")
with gr.Column(scale=2):
gr.Markdown("### Kết quả trích xuất")
output_table = gr.DataFrame(label="Chi tiết đơn hàng")
output_excel = gr.File(label="Tải file Excel")
gr.Markdown("### Văn bản đọc được từ ảnh (OCR)")
output_text = gr.Textbox(label="Text from Image", lines=10, interactive=False)
process_btn.click(fn=process_image_and_extract, inputs=image_input, outputs=[output_text, output_table, output_excel])
app.launch(debug=True)