File size: 3,545 Bytes
574930d
 
 
53a6054
 
 
 
 
574930d
53a6054
 
 
 
 
28e990c
53a6054
 
 
 
 
 
 
28e990c
 
53a6054
 
 
 
 
28e990c
53a6054
574930d
53a6054
 
574930d
53a6054
574930d
 
53a6054
 
 
 
574930d
 
 
53a6054
 
574930d
53a6054
574930d
53a6054
 
 
 
 
574930d
 
 
 
 
 
53a6054
 
 
574930d
 
53a6054
28e990c
574930d
28e990c
53a6054
574930d
53a6054
 
574930d
 
28e990c
 
 
574930d
28e990c
574930d
 
 
28e990c
53a6054
574930d
53a6054
 
574930d
28e990c
574930d
53a6054
 
 
 
574930d
53a6054
 
 
 
 
574930d
53a6054
 
 
 
931a8d2
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
import os
import gc
import torch
import gradio as gr
from PIL import Image
from transformers import Qwen2_5_VLForConditionalGeneration, AutoProcessor
from peft import PeftModel

# Force sync for debugging if needed
os.environ['CUDA_LAUNCH_BLOCKING'] = '1'

# --- Configuration ---
base_model_id = "Qwen/Qwen2.5-VL-3B-Instruct"
adapter_id = "A-M-R-A-G/Basira"
hf_token = os.getenv("token_HF")

# --- Model Loading ---
print("Loading base model...")
model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
    base_model_id,
    torch_dtype=torch.float16,
    device_map="auto",
    trust_remote_code=True,
    token=hf_token
)

print("Loading processor...")
processor = AutoProcessor.from_pretrained(
    base_model_id,
    token=hf_token
)
processor.tokenizer.padding_side = "right"

print("Loading and applying adapter...")
# Using the direct model load to bypass the PEFT KeyError bug
model = PeftModel.from_pretrained(model, adapter_id)
model.eval()

print("Model loaded successfully!")

# --- The Inference Function ---
def perform_ocr_on_image(image_input: Image.Image) -> str:
    """
    Transcribe the Arabic text contained in a single uploaded image.

    Parameters
    ----------
    image_input : PIL.Image.Image or None
        The uploaded document image; Gradio passes ``None`` when the user
        submits without uploading anything.

    Returns
    -------
    str
        The extracted text (stripped), or a human-readable error message
        when input is missing or inference fails.
    """
    if image_input is None:
        return "Please upload an image."

    try:
        # Build the multimodal chat prompt in the format Qwen2.5-VL expects.
        messages = [
            {
                "role": "user",
                "content": [
                    {"type": "image", "image": image_input},
                    {"type": "text", "text": (
                        "Analyze the input image and detect all Arabic text. "
                        "Output only the extracted text—verbatim and in its original script—"
                        "without any added commentary, translation, punctuation or formatting. "
                        "Present each line of text as plain UTF-8 strings, with no extra characters or words."
                    )},
                ],
            }
        ]

        # Render the conversation into the model's chat-template string.
        text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)

        # Tokenize the prompt and preprocess the image, then move the
        # resulting tensors onto the same device as the model.
        inputs = processor(text=[text], images=[image_input], padding=True, return_tensors="pt").to(model.device)

        # inference_mode() is the recommended pure-inference context: like
        # no_grad() but also disables view tracking for a small speedup.
        with torch.inference_mode():
            generated_ids = model.generate(**inputs, max_new_tokens=512)

        # generate() returns prompt + completion; drop the prompt tokens so
        # only the newly generated response gets decoded.
        generated_ids_trimmed = [
            out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
        ]

        cleaned_response = processor.batch_decode(
            generated_ids_trimmed,
            skip_special_tokens=True,
            clean_up_tokenization_spaces=False
        )[0]

        # Free per-request memory. Only touch the CUDA allocator when a GPU
        # is actually available (the original called empty_cache()
        # unconditionally, which needlessly initializes CUDA on CPU hosts).
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

        return cleaned_response.strip()

    except Exception as e:
        # Top-level demo boundary: surface the failure in the UI instead of
        # crashing the Gradio worker.
        print(f"An error occurred during inference: {e}")
        return f"An error occurred: {str(e)}"

# --- Create and Launch the Gradio Interface ---
# Single-function demo: one image input mapped straight to a text output.
demo = gr.Interface(
    fn=perform_ocr_on_image,
    # type="pil" hands the function a PIL.Image, matching its signature.
    inputs=gr.Image(type="pil", label="Upload Arabic Document Image"),
    outputs=gr.Textbox(label="Transcription", lines=10, show_copy_button=True),
    title="Basira: Fine-Tuned Qwen-VL for Arabic OCR",
    description="A demo for the Qwen-VL 2.5 (3B) model, fine-tuned for enhanced Arabic OCR. Upload an image to see the transcription.",
    # Disable the flagging button; transcriptions are not collected.
    # NOTE(review): allow_flagging is deprecated in Gradio 4.x in favor of
    # flagging_mode — confirm the pinned Gradio version before migrating.
    allow_flagging="never"
)

# Launch only when run as a script (not when imported, e.g. by HF Spaces tooling).
if __name__ == "__main__":
    demo.launch()