A-M-R-A-G committed on
Commit
574930d
·
verified ·
1 Parent(s): 28e990c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +34 -23
app.py CHANGED
@@ -1,11 +1,12 @@
 
 
 
1
  import gradio as gr
2
  from PIL import Image
3
- import torch
4
  from transformers import Qwen2_5_VLForConditionalGeneration, AutoProcessor
5
  from peft import PeftModel
6
- import gc
7
- import os
8
 
 
9
  os.environ['CUDA_LAUNCH_BLOCKING'] = '1'
10
 
11
  # --- Configuration ---
@@ -15,7 +16,6 @@ hf_token = os.getenv("token_HF")
15
 
16
  # --- Model Loading ---
17
  print("Loading base model...")
18
- # Added device_map="auto" and trust_remote_code
19
  model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
20
  base_model_id,
21
  torch_dtype=torch.float16,
@@ -29,67 +29,78 @@ processor = AutoProcessor.from_pretrained(
29
  base_model_id,
30
  token=hf_token
31
  )
 
32
 
33
  print("Loading and applying adapter...")
34
- # FIX: Use the 'model' attribute specifically if Peft struggles with the wrapper
35
- # Or simply ensure the base model is fully loaded before wrapping
36
  model = PeftModel.from_pretrained(model, adapter_id)
37
- model.eval() # Set to evaluation mode
 
38
  print("Model loaded successfully!")
39
 
40
  # --- The Inference Function ---
41
  def perform_ocr_on_image(image_input: Image.Image) -> str:
 
 
 
42
  if image_input is None:
43
  return "Please upload an image."
44
-
45
  try:
46
- # 1. Format the prompt
47
  messages = [
48
  {
49
  "role": "user",
50
  "content": [
51
  {"type": "image", "image": image_input},
52
- {"type": "text", "text": "Analyze the input image and detect all Arabic text. Output only the extracted text—verbatim and in its original script."},
 
 
 
 
 
53
  ],
54
  }
55
  ]
56
-
57
- # 2. Apply chat template
58
  text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
59
 
60
- # 3. Prepare inputs correctly for Qwen2.5-VL
61
- # Note: Some versions require 'images' to be a list even if it's one image
62
  inputs = processor(text=[text], images=[image_input], padding=True, return_tensors="pt").to(model.device)
63
 
64
- # 4. Generate
65
  with torch.no_grad():
66
- # Use the underlying model's generation to avoid PEFT wrapper conflicts
67
  generated_ids = model.generate(**inputs, max_new_tokens=512)
68
-
69
- # 5. Decode only the NEW tokens to avoid manual string splitting
70
  generated_ids_trimmed = [
71
  out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
72
  ]
 
73
  cleaned_response = processor.batch_decode(
74
- generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
 
 
75
  )[0]
76
 
77
- # Clean up
78
  gc.collect()
79
  torch.cuda.empty_cache()
 
80
  return cleaned_response.strip()
81
-
82
  except Exception as e:
83
  print(f"An error occurred during inference: {e}")
84
  return f"An error occurred: {str(e)}"
85
 
86
- # --- Interface ---
87
  demo = gr.Interface(
88
  fn=perform_ocr_on_image,
89
  inputs=gr.Image(type="pil", label="Upload Arabic Document Image"),
90
  outputs=gr.Textbox(label="Transcription", lines=10, show_copy_button=True),
91
  title="Basira: Fine-Tuned Qwen-VL for Arabic OCR",
92
- description="A demo for the Qwen-VL 2.5 (3B) model, fine-tuned for enhanced Arabic OCR.",
93
  allow_flagging="never"
94
  )
95
 
 
1
+ import os
2
+ import gc
3
+ import torch
4
  import gradio as gr
5
  from PIL import Image
 
6
  from transformers import Qwen2_5_VLForConditionalGeneration, AutoProcessor
7
  from peft import PeftModel
 
 
8
 
9
+ # Force sync for debugging if needed
10
  os.environ['CUDA_LAUNCH_BLOCKING'] = '1'
11
 
12
  # --- Configuration ---
 
16
 
17
  # --- Model Loading ---
18
  print("Loading base model...")
 
19
  model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
20
  base_model_id,
21
  torch_dtype=torch.float16,
 
29
  base_model_id,
30
  token=hf_token
31
  )
32
+ processor.tokenizer.padding_side = "right"
33
 
34
  print("Loading and applying adapter...")
35
+ # Using the direct model load to bypass the PEFT KeyError bug
 
36
  model = PeftModel.from_pretrained(model, adapter_id)
37
+ model.eval()
38
+
39
  print("Model loaded successfully!")
40
 
41
  # --- The Inference Function ---
42
  def perform_ocr_on_image(image_input: Image.Image) -> str:
43
+ """
44
+ Takes a PIL image and returns the transcribed Arabic text.
45
+ """
46
  if image_input is None:
47
  return "Please upload an image."
48
+
49
  try:
50
+ # Format the prompt using the chat template
51
  messages = [
52
  {
53
  "role": "user",
54
  "content": [
55
  {"type": "image", "image": image_input},
56
+ {"type": "text", "text": (
57
+ "Analyze the input image and detect all Arabic text. "
58
+ "Output only the extracted text—verbatim and in its original script—"
59
+ "without any added commentary, translation, punctuation or formatting. "
60
+ "Present each line of text as plain UTF-8 strings, with no extra characters or words."
61
+ )},
62
  ],
63
  }
64
  ]
65
+
66
+ # Apply template
67
  text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
68
 
69
+ # Prepare inputs
 
70
  inputs = processor(text=[text], images=[image_input], padding=True, return_tensors="pt").to(model.device)
71
 
72
+ # Generate prediction
73
  with torch.no_grad():
 
74
  generated_ids = model.generate(**inputs, max_new_tokens=512)
75
+
76
+ # Trim the input tokens from the output to get only the response
77
  generated_ids_trimmed = [
78
  out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
79
  ]
80
+
81
  cleaned_response = processor.batch_decode(
82
+ generated_ids_trimmed,
83
+ skip_special_tokens=True,
84
+ clean_up_tokenization_spaces=False
85
  )[0]
86
 
87
+ # Clean up memory
88
  gc.collect()
89
  torch.cuda.empty_cache()
90
+
91
  return cleaned_response.strip()
92
+
93
  except Exception as e:
94
  print(f"An error occurred during inference: {e}")
95
  return f"An error occurred: {str(e)}"
96
 
97
+ # --- Create and Launch the Gradio Interface ---
98
  demo = gr.Interface(
99
  fn=perform_ocr_on_image,
100
  inputs=gr.Image(type="pil", label="Upload Arabic Document Image"),
101
  outputs=gr.Textbox(label="Transcription", lines=10, show_copy_button=True),
102
  title="Basira: Fine-Tuned Qwen-VL for Arabic OCR",
103
+ description="A demo for the Qwen-VL 2.5 (3B) model, fine-tuned for enhanced Arabic OCR. Upload an image to see the transcription.",
104
  allow_flagging="never"
105
  )
106