# app.py — Basira: Qwen2.5-VL (3B) fine-tuned for Arabic OCR
# Hugging Face Space by A-M-R-A-G (commit 574930d).
import os
import gc
import torch
import gradio as gr
from PIL import Image
from transformers import Qwen2_5_VLForConditionalGeneration, AutoProcessor
from peft import PeftModel
# Force synchronous CUDA kernel launches so CUDA errors surface at the
# failing call instead of asynchronously later (debug aid; adds latency).
os.environ['CUDA_LAUNCH_BLOCKING'] = '1'

# --- Configuration ---
base_model_id = "Qwen/Qwen2.5-VL-3B-Instruct"  # base vision-language model
adapter_id = "A-M-R-A-G/Basira"  # PEFT adapter fine-tuned for Arabic OCR
hf_token = os.getenv("token_HF")  # HF access token; None if the env var is unset

# --- Model Loading ---
print("Loading base model...")
model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
    base_model_id,
    torch_dtype=torch.float16,  # half precision to fit the 3B model in memory
    device_map="auto",          # let accelerate place layers on available device(s)
    trust_remote_code=True,
    token=hf_token
)

print("Loading processor...")
processor = AutoProcessor.from_pretrained(
    base_model_id,
    token=hf_token
)
# NOTE(review): right padding set explicitly — presumably to match the
# fine-tuning setup; confirm against the training config.
processor.tokenizer.padding_side = "right"

print("Loading and applying adapter...")
# Using the direct model load to bypass the PEFT KeyError bug
model = PeftModel.from_pretrained(model, adapter_id)
model.eval()  # inference mode: disables dropout etc.
print("Model loaded successfully!")
# --- The Inference Function ---
def perform_ocr_on_image(image_input: Image.Image) -> str:
    """
    Transcribe the Arabic text found in a PIL image.

    Args:
        image_input: The uploaded document image; Gradio passes ``None``
            when no image was provided.

    Returns:
        The extracted text (stripped of surrounding whitespace), a prompt
        asking the user to upload an image, or a human-readable error
        message if inference fails.
    """
    if image_input is None:
        return "Please upload an image."
    try:
        # Build a chat-format request: the image plus an instruction asking
        # for a verbatim, plain-text transcription of all Arabic text.
        messages = [
            {
                "role": "user",
                "content": [
                    {"type": "image", "image": image_input},
                    {"type": "text", "text": (
                        "Analyze the input image and detect all Arabic text. "
                        "Output only the extracted text—verbatim and in its original script—"
                        "without any added commentary, translation, punctuation or formatting. "
                        "Present each line of text as plain UTF-8 strings, with no extra characters or words."
                    )},
                ],
            }
        ]
        # Render the messages with the model's chat template so the special
        # tokens match what the model was trained with.
        text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
        # Tokenize the prompt and preprocess the image, then move all
        # tensors onto the model's device.
        inputs = processor(text=[text], images=[image_input], padding=True, return_tensors="pt").to(model.device)
        # No gradients needed for generation.
        with torch.no_grad():
            generated_ids = model.generate(**inputs, max_new_tokens=512)
        # generate() returns prompt + completion; slice off the prompt
        # tokens so only the new response is decoded.
        generated_ids_trimmed = [
            out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
        ]
        cleaned_response = processor.batch_decode(
            generated_ids_trimmed,
            skip_special_tokens=True,
            clean_up_tokenization_spaces=False
        )[0]
        return cleaned_response.strip()
    except Exception as e:
        # Top-level boundary for the Gradio worker: log and surface the
        # failure to the UI instead of crashing the request.
        print(f"An error occurred during inference: {e}")
        return f"An error occurred: {str(e)}"
    finally:
        # Reclaim host and GPU memory on success AND failure — the original
        # skipped cleanup when generation raised, which is exactly when
        # VRAM pressure is most likely.
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
# --- Create and Launch the Gradio Interface ---
# Single-input / single-output UI: one image in, one text transcription out.
demo = gr.Interface(
    fn=perform_ocr_on_image,
    inputs=gr.Image(type="pil", label="Upload Arabic Document Image"),  # delivered as a PIL.Image
    outputs=gr.Textbox(label="Transcription", lines=10, show_copy_button=True),
    title="Basira: Fine-Tuned Qwen-VL for Arabic OCR",
    description="A demo for the Qwen-VL 2.5 (3B) model, fine-tuned for enhanced Arabic OCR. Upload an image to see the transcription.",
    allow_flagging="never"  # NOTE(review): deprecated in Gradio 4.x (renamed flagging_mode) — confirm installed version
)

# Launch only when run as a script (HF Spaces executes app.py directly).
if __name__ == "__main__":
    demo.launch()