"""Gradio demo for Basira: Qwen2.5-VL (3B) fine-tuned for Arabic OCR."""

import gc
import os

import gradio as gr
import torch
from PIL import Image
from peft import PeftModel
from transformers import AutoProcessor, Qwen2_5_VLForConditionalGeneration

# Force synchronous CUDA launches only when explicitly requested.
# Setting CUDA_LAUNCH_BLOCKING=1 unconditionally (as before) serializes every
# kernel launch and noticeably slows inference in production; opt in via the
# BASIRA_DEBUG_CUDA environment variable when debugging device-side errors.
if os.getenv("BASIRA_DEBUG_CUDA"):
    os.environ["CUDA_LAUNCH_BLOCKING"] = "1"

# --- Configuration ---
base_model_id = "Qwen/Qwen2.5-VL-3B-Instruct"
adapter_id = "A-M-R-A-G/Basira"
hf_token = os.getenv("token_HF")

# --- Model Loading ---
print("Loading base model...")
model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
    base_model_id,
    torch_dtype=torch.float16,
    device_map="auto",
    trust_remote_code=True,
    token=hf_token,
)

print("Loading processor...")
processor = AutoProcessor.from_pretrained(base_model_id, token=hf_token)
processor.tokenizer.padding_side = "right"

print("Loading and applying adapter...")
# Using the direct model load to bypass the PEFT KeyError bug
model = PeftModel.from_pretrained(model, adapter_id)
model.eval()
print("Model loaded successfully!")


def perform_ocr_on_image(image_input: Image.Image) -> str:
    """Transcribe the Arabic text found in a PIL image.

    Args:
        image_input: The uploaded document image, or ``None`` when the user
            submits the form without an upload.

    Returns:
        The extracted text (stripped), or a human-readable error message —
        errors are returned as strings rather than raised so the Gradio UI
        can display them.
    """
    if image_input is None:
        return "Please upload an image."

    try:
        # Format the prompt using the chat template.
        messages = [
            {
                "role": "user",
                "content": [
                    {"type": "image", "image": image_input},
                    {
                        "type": "text",
                        "text": (
                            "Analyze the input image and detect all Arabic text. "
                            "Output only the extracted text—verbatim and in its original script—"
                            "without any added commentary, translation, punctuation or formatting. "
                            "Present each line of text as plain UTF-8 strings, with no extra characters or words."
                        ),
                    },
                ],
            }
        ]

        # Apply template, then tokenize text + image together.
        text = processor.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )
        inputs = processor(
            text=[text], images=[image_input], padding=True, return_tensors="pt"
        ).to(model.device)

        # Generate prediction.
        with torch.no_grad():
            generated_ids = model.generate(**inputs, max_new_tokens=512)

        # Trim the input tokens from the output to get only the response.
        generated_ids_trimmed = [
            out_ids[len(in_ids):]
            for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
        ]
        cleaned_response = processor.batch_decode(
            generated_ids_trimmed,
            skip_special_tokens=True,
            clean_up_tokenization_spaces=False,
        )[0]
        return cleaned_response.strip()
    except Exception as e:
        print(f"An error occurred during inference: {e}")
        return f"An error occurred: {str(e)}"
    finally:
        # Free transient tensors even when inference fails; previously this
        # cleanup only ran on the success path. Guard the CUDA call so the
        # app also runs on CPU-only hosts.
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()


# --- Create and Launch the Gradio Interface ---
demo = gr.Interface(
    fn=perform_ocr_on_image,
    inputs=gr.Image(type="pil", label="Upload Arabic Document Image"),
    outputs=gr.Textbox(label="Transcription", lines=10, show_copy_button=True),
    title="Basira: Fine-Tuned Qwen-VL for Arabic OCR",
    description="A demo for the Qwen-VL 2.5 (3B) model, fine-tuned for enhanced Arabic OCR. Upload an image to see the transcription.",
    allow_flagging="never",
)

if __name__ == "__main__":
    demo.launch()