A-M-R-A-G committed on
Commit
28e990c
·
verified ·
1 Parent(s): 53a6054

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +31 -40
app.py CHANGED
@@ -6,99 +6,90 @@ from peft import PeftModel
6
  import gc
7
  import os
8
 
9
- # Add this line immediately after your imports
10
  os.environ['CUDA_LAUNCH_BLOCKING'] = '1'
11
 
12
  # --- Configuration ---
13
  base_model_id = "Qwen/Qwen2.5-VL-3B-Instruct"
14
  adapter_id = "A-M-R-A-G/Basira"
 
15
 
16
  # --- Model Loading ---
17
  print("Loading base model...")
 
18
  model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
19
  base_model_id,
20
  torch_dtype=torch.float16,
21
  device_map="auto",
22
- token=os.getenv("token_HF")
 
23
  )
24
 
25
  print("Loading processor...")
26
  processor = AutoProcessor.from_pretrained(
27
  base_model_id,
28
- token=os.getenv("token_HF")
29
  )
30
- processor.tokenizer.padding_side = "right"
31
 
32
  print("Loading and applying adapter...")
 
 
33
  model = PeftModel.from_pretrained(model, adapter_id)
 
34
  print("Model loaded successfully!")
35
 
36
  # --- The Inference Function ---
37
  def perform_ocr_on_image(image_input: Image.Image) -> str:
38
- """
39
- This is the core function that Gradio will call.
40
- It takes a PIL image and returns the transcribed text string.
41
- """
42
  if image_input is None:
43
  return "Please upload an image."
44
-
45
  try:
46
- # Format the prompt using the chat template
47
  messages = [
48
  {
49
  "role": "user",
50
  "content": [
51
  {"type": "image", "image": image_input},
52
- {"type": "text", "text": (
53
- "Analyze the input image and detect all Arabic text. "
54
- "Output only the extracted text—verbatim and in its original script—"
55
- "without any added commentary, translation, punctuation or formatting. "
56
- "Present each line of text as plain UTF-8 strings, with no extra characters or words."
57
- )},
58
  ],
59
  }
60
  ]
 
 
61
  text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
 
 
 
 
62
 
63
- # Prepare inputs for the model
64
- inputs = processor(text=text, images=image_input, return_tensors="pt").to(model.device)
65
-
66
- # Generate prediction
67
  with torch.no_grad():
 
68
  generated_ids = model.generate(**inputs, max_new_tokens=512)
69
-
70
- # Decode the output
71
- full_response = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
72
 
73
- # --- FIX: Post-process the response to remove the prompt ---
74
- # The model's actual output starts after the "assistant" marker.
75
- # We split the full response by this marker and take the last part.
76
- parts = full_response.split("assistant")
77
- if len(parts) > 1:
78
- # Take the last part and remove any leading/trailing whitespace
79
- cleaned_response = parts[-1].strip()
80
- else:
81
- # If the marker isn't found, return the full response as a fallback
82
- cleaned_response = full_response
83
- # --- END OF FIX ---
84
 
85
- # Clean up memory
86
  gc.collect()
87
  torch.cuda.empty_cache()
88
-
89
- return cleaned_response
90
-
91
  except Exception as e:
92
  print(f"An error occurred during inference: {e}")
93
  return f"An error occurred: {str(e)}"
94
 
95
- # --- Create and Launch the Gradio Interface ---
96
  demo = gr.Interface(
97
  fn=perform_ocr_on_image,
98
  inputs=gr.Image(type="pil", label="Upload Arabic Document Image"),
99
  outputs=gr.Textbox(label="Transcription", lines=10, show_copy_button=True),
100
  title="Basira: Fine-Tuned Qwen-VL for Arabic OCR",
101
- description="A demo for the Qwen-VL 2.5 (3B) model, fine-tuned for enhanced Arabic OCR. Upload an image to see the transcription.",
102
  allow_flagging="never"
103
  )
104
 
 
import gc
import os

# Force synchronous CUDA kernel launches so CUDA errors surface at the
# failing call site instead of asynchronously later (debug aid).
os.environ['CUDA_LAUNCH_BLOCKING'] = '1'

# --- Configuration ---
base_model_id = "Qwen/Qwen2.5-VL-3B-Instruct"
adapter_id = "A-M-R-A-G/Basira"
# HF access token read from the Space/environment secret named "token_HF".
hf_token = os.getenv("token_HF")

# --- Model Loading ---
print("Loading base model...")
model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
    base_model_id,
    torch_dtype=torch.float16,      # half precision to fit the 3B model on GPU
    device_map="auto",              # let accelerate place layers across devices
    trust_remote_code=True,
    token=hf_token,
)

print("Loading processor...")
processor = AutoProcessor.from_pretrained(
    base_model_id,
    token=hf_token,
)

print("Loading and applying adapter...")
# Wrap the fully loaded base model with the fine-tuned LoRA adapter.
model = PeftModel.from_pretrained(model, adapter_id)
model.eval()  # inference only: disable dropout etc.
print("Model loaded successfully!")
# --- The Inference Function ---
def perform_ocr_on_image(image_input: Image.Image) -> str:
    """Run Arabic OCR on a PIL image and return the transcribed text.

    Called by Gradio for each uploaded image. Returns the extracted text,
    a prompt to upload an image when input is None, or an error message
    string if inference fails (Gradio displays whatever string we return).
    """
    if image_input is None:
        return "Please upload an image."
    try:
        # 1. Build the chat-style prompt expected by Qwen2.5-VL.
        messages = [
            {
                "role": "user",
                "content": [
                    {"type": "image", "image": image_input},
                    {"type": "text", "text": "Analyze the input image and detect all Arabic text. Output only the extracted text—verbatim and in its original script."},
                ],
            }
        ]

        # 2. Render the template (tokenize later, together with the image).
        text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)

        # 3. Prepare model inputs. Text and images are passed as lists —
        #    some processor versions require list inputs even for one image.
        inputs = processor(text=[text], images=[image_input], padding=True, return_tensors="pt").to(model.device)

        # 4. Generate the transcription (no gradients needed at inference).
        with torch.no_grad():
            generated_ids = model.generate(**inputs, max_new_tokens=512)

        # 5. Decode only the newly generated tokens: slicing off the prompt
        #    tokens avoids fragile string-splitting on the "assistant" marker.
        generated_ids_trimmed = [
            out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
        ]
        cleaned_response = processor.batch_decode(
            generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
        )[0]

        # Release transient tensors between requests to limit GPU memory growth.
        gc.collect()
        torch.cuda.empty_cache()
        return cleaned_response.strip()

    except Exception as e:
        # Top-level boundary for the web UI: surface the error to the user
        # instead of crashing the Gradio worker.
        print(f"An error occurred during inference: {e}")
        return f"An error occurred: {str(e)}"
# --- Interface ---
# Gradio UI wiring: one image in, one copyable text box out.
demo = gr.Interface(
    fn=perform_ocr_on_image,
    inputs=gr.Image(type="pil", label="Upload Arabic Document Image"),
    outputs=gr.Textbox(label="Transcription", lines=10, show_copy_button=True),
    title="Basira: Fine-Tuned Qwen-VL for Arabic OCR",
    description="A demo for the Qwen-VL 2.5 (3B) model, fine-tuned for enhanced Arabic OCR.",
    allow_flagging="never",
)