# Basira — Hugging Face Space: Arabic OCR demo built on a fine-tuned Qwen2.5-VL (3B).
import os
import gc
import torch
import gradio as gr
from PIL import Image
from transformers import Qwen2_5_VLForConditionalGeneration, AutoProcessor
from peft import PeftModel

# CUDA_LAUNCH_BLOCKING=1 serializes every CUDA kernel launch, which helps
# pinpoint the exact failing line of a CUDA error but significantly slows
# inference. The original set it unconditionally ("for debugging if needed");
# make it opt-in instead so production inference runs asynchronously.
if os.getenv("DEBUG_CUDA_SYNC"):
    os.environ["CUDA_LAUNCH_BLOCKING"] = "1"

# --- Configuration ---
base_model_id = "Qwen/Qwen2.5-VL-3B-Instruct"  # frozen base checkpoint
adapter_id = "A-M-R-A-G/Basira"                # LoRA adapter fine-tuned for Arabic OCR
hf_token = os.getenv("token_HF")               # HF access token; may be None for public repos

# --- Model Loading ---
print("Loading base model...")
model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
    base_model_id,
    torch_dtype=torch.float16,  # half precision to fit the 3B model in GPU memory
    device_map="auto",          # let accelerate place layers on available devices
    trust_remote_code=True,
    token=hf_token
)

print("Loading processor...")
processor = AutoProcessor.from_pretrained(
    base_model_id,
    token=hf_token
)
# NOTE(review): decoder-only models are usually padded on the LEFT for batched
# generation; right padding presumably matches the fine-tuning setup here and
# is harmless for the single-image requests this app makes — confirm if
# batching is ever added.
processor.tokenizer.padding_side = "right"

print("Loading and applying adapter...")
# Using the direct model load to bypass the PEFT KeyError bug
model = PeftModel.from_pretrained(model, adapter_id)
model.eval()  # disable dropout etc. for deterministic inference
print("Model loaded successfully!")
| # --- The Inference Function --- | |
def perform_ocr_on_image(image_input: Image.Image) -> str:
    """
    Transcribe the Arabic text in a PIL image using the fine-tuned model.

    Args:
        image_input: The uploaded image, or None when nothing was provided.

    Returns:
        The extracted text (stripped), a prompt to upload an image when the
        input is None, or an error message string if inference fails.
    """
    if image_input is None:
        return "Please upload an image."
    try:
        # Format the prompt using the chat template
        messages = [
            {
                "role": "user",
                "content": [
                    {"type": "image", "image": image_input},
                    {"type": "text", "text": (
                        "Analyze the input image and detect all Arabic text. "
                        "Output only the extracted text—verbatim and in its original script—"
                        "without any added commentary, translation, punctuation or formatting. "
                        "Present each line of text as plain UTF-8 strings, with no extra characters or words."
                    )},
                ],
            }
        ]
        # Render the chat template to the raw prompt string the processor expects.
        text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
        # Tokenize text + image and move the tensors to the model's device.
        inputs = processor(text=[text], images=[image_input], padding=True, return_tensors="pt").to(model.device)
        # Generate without gradient tracking (inference only).
        with torch.no_grad():
            generated_ids = model.generate(**inputs, max_new_tokens=512)
        # Trim the input tokens from the output to get only the response
        generated_ids_trimmed = [
            out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
        ]
        cleaned_response = processor.batch_decode(
            generated_ids_trimmed,
            skip_special_tokens=True,
            clean_up_tokenization_spaces=False
        )[0]
        return cleaned_response.strip()
    except Exception as e:
        # Gradio boundary: surface the error in the UI instead of crashing the app.
        print(f"An error occurred during inference: {e}")
        return f"An error occurred: {str(e)}"
    finally:
        # Free transient tensors on BOTH success and failure paths; the
        # original cleaned up only on success, so a failed generation left
        # its allocations behind. Guard empty_cache for CPU-only hosts.
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
# --- Create and Launch the Gradio Interface ---
# Name the components up front so the Interface call stays compact.
image_input_component = gr.Image(type="pil", label="Upload Arabic Document Image")
transcription_component = gr.Textbox(label="Transcription", lines=10, show_copy_button=True)

demo = gr.Interface(
    fn=perform_ocr_on_image,
    inputs=image_input_component,
    outputs=transcription_component,
    title="Basira: Fine-Tuned Qwen-VL for Arabic OCR",
    description="A demo for the Qwen-VL 2.5 (3B) model, fine-tuned for enhanced Arabic OCR. Upload an image to see the transcription.",
    allow_flagging="never"
)

# Start the web server only when executed as a script, not on import.
if __name__ == "__main__":
    demo.launch()