ivanhoang's picture
Update app.py
a31ce93 verified
raw
history blame
6.21 kB
import gradio as gr
import pandas as pd
from datetime import datetime
import json
import io
from PIL import Image
import torch
from transformers import AutoProcessor, AutoModelForCausalLM
from ctransformers import AutoModelForCausalLM as CAutoModelForCausalLM
# --- CẤU HÌNH VÀ TẢI MÔ HÌNH (PHIÊN BẢN ỔN ĐỊNH NHẤT CHO CPU) ---
device = "cpu"
print(f"Đang sử dụng thiết bị: {device}")
# 1. Tải mô hình OCR (Giữ nguyên, đã ổn định)
print("Đang tải mô hình OCR (Florence-2-base)...")
ocr_model_id = "microsoft/Florence-2-base"
ocr_processor = AutoProcessor.from_pretrained(ocr_model_id, trust_remote_code=True)
ocr_model = AutoModelForCausalLM.from_pretrained(
ocr_model_id,
device_map=device,
torch_dtype=torch.float32,
trust_remote_code=True,
attn_implementation="eager"
)
print("Tải xong mô hình OCR.")
# 2. DÙNG CTRANSFORMERS ĐỂ TẢI GGUF (ĐỔI NHÀ CUNG CẤP)
print("Đang tải mô hình LLM (Llama-3-8B GGUF for CPU)...")
# SỬ DỤNG PHIÊN BẢN GGUF TỪ MỘT NHÀ CUNG CẤP KHÁC
llm_model_id = "bartowski/Meta-Llama-3-8B-Instruct-GGUF"
llm_model_file = "Meta-Llama-3-8B-Instruct-Q4_K_M.gguf"
llm = CAutoModelForCausalLM.from_pretrained(
llm_model_id,
model_file=llm_model_file,
model_type="llama",
gpu_layers=0,
context_length=4096
)
print("Tải xong mô hình LLM.")
# --- CÁC HÀM XỬ LÝ ---
def run_ocr(image: Image.Image) -> str:
if image is None: return "Lỗi: Vui lòng cung cấp hình ảnh."
prompt = "<OCR>"
inputs = ocr_processor(text=prompt, images=image, return_tensors="pt").to(device)
generated_ids = ocr_model.generate(
input_ids=inputs["input_ids"],
pixel_values=inputs["pixel_values"],
max_new_tokens=2048,
num_beams=3
)
generated_text = ocr_processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
parsed_text = ocr_processor.post_process_generation(generated_text, task="<OCR>", image_size=(image.width, image.height))
return parsed_text['<OCR>']
def extract_order_from_text(text: str) -> dict:
prompt = f"""<|begin_of_text|><|start_header_id|>system<|end_header_id|>
You are an expert assistant that only outputs valid JSON. Extract order information from the text. The JSON object must contain "ten_khach_hang" (string, null if not found) and "danh_sach_hang" (an array of items). Each item must have "ten_hang" (string), "so_luong" (number), "don_vi" (string), "ma_hang" (string, null if not found), and "ghi_chu" (string, null if not found).<|eot_id|><|start_header_id|>user<|end_header_id|>
Text Content:
{text}<|eot_id|><|start_header_id|>assistant<|end_header_id|>
"""
response_text = llm(prompt, max_new_tokens=1024, temperature=0.1, stop=["<|eot_id|>"])
try:
json_str = response_text.strip()
start = json_str.find('{')
end = json_str.rfind('}') + 1
if start != -1 and end != 0:
json_str = json_str[start:end]
return json.loads(json_str)
except json.JSONDecodeError:
return {"error": "AI trả về định dạng không hợp lệ", "raw_response": response_text}
def create_excel_file(order_data: dict):
if not order_data or "danh_sach_hang" not in order_data or not order_data["danh_sach_hang"]: return None
flat_data = []
customer = order_data.get('ten_khach_hang', 'N/A')
for item in order_data['danh_sach_hang']:
flat_data.append({
'Khách hàng': customer, 'Mã hàng': item.get('ma_hang'),
'Tên hàng': item.get('ten_hang'), 'Số lượng': item.get('so_luong'),
'Đơn vị': item.get('don_vi'), 'Ghi chú': item.get('ghi_chu')
})
df = pd.DataFrame(flat_data)
output = io.BytesIO()
with pd.ExcelWriter(output, engine='openpyxl') as writer: df.to_excel(writer, index=False, sheet_name='DonHang')
timestamp = datetime.now().strftime("%Ym%d_%H%M%S")
filename = f"don_hang_{timestamp}.xlsx"
return (filename, output.getvalue())
def process_image_and_extract(image):
try:
if image is None: return "Vui lòng dán ảnh vào.", None, None
extracted_text = run_ocr(image)
if not extracted_text.strip(): return "Không đọc được chữ từ hình ảnh.", None, None
order_data = extract_order_from_text(extracted_text)
if "error" in order_data: return extracted_text, f"Lỗi từ AI: {order_data['error']}\nPhản hồi gốc: {order_data['raw_response']}", None
excel_info = create_excel_file(order_data)
df_display = pd.DataFrame(order_data.get('danh_sach_hang', []))
if excel_info:
filename, filebytes = excel_info
with open(filename, "wb") as f: f.write(filebytes)
return extracted_text, df_display, filename
else: return extracted_text, df_display, None
except Exception as e:
import traceback
error_str = str(e)
traceback_str = traceback.format_exc()
print(traceback_str)
return f"Lỗi nghiêm trọng: {error_str}", None, None
# --- XÂY DỰNG GIAO DIỆN GRADIO ---
with gr.Blocks(theme=gr.themes.Soft()) as app:
gr.Markdown("# Ứng dụng Trích xuất Đơn hàng từ Ảnh chụp màn hình")
gr.Markdown("Chụp màn hình email/tin nhắn đặt hàng, sau đó dán (Ctrl+V) vào ô bên dưới và nhấn 'Xử lý'.")
with gr.Row():
with gr.Column(scale=1):
image_input = gr.Image(label="Dán ảnh chụp màn hình vào đây", type="pil", sources=["clipboard", "upload"])
process_btn = gr.Button("Xử lý", variant="primary")
with gr.Column(scale=2):
gr.Markdown("### Kết quả trích xuất")
output_table = gr.DataFrame(label="Chi tiết đơn hàng")
output_excel = gr.File(label="Tải file Excel")
gr.Markdown("### Văn bản đọc được từ ảnh (OCR)")
output_text = gr.Textbox(label="Text from Image", lines=10, interactive=False)
process_btn.click(fn=process_image_and_extract, inputs=image_input, outputs=[output_text, output_table, output_excel])
app.launch(debug=True)