imperiusrex committed on
Commit
3e0219f
·
verified ·
1 Parent(s): eddfdb3

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +25 -30
app.py CHANGED
@@ -1,24 +1,25 @@
1
  import gradio as gr
 
2
  from transformers import CLIPProcessor, CLIPModel
3
  from paddleocr import PaddleOCR, TextDetection
4
  from PIL import Image
5
- import torch
6
  import numpy as np
7
  import cv2
8
- import os
9
  import spaces
10
 
11
  # --- Global setup for models and data ---
12
- # This section runs once when the app starts.
13
- print("Initializing models...")
14
 
15
- # Load CLIP model once.
16
- # By default, Hugging Face transformers will load models to the GPU if available.
17
- clip_model = CLIPModel.from_pretrained("openai/clip-vit-large-patch14").to("cuda" if torch.cuda.is_available() else "cpu")
 
 
 
18
  processor = CLIPProcessor.from_pretrained("openai/clip-vit-large-patch14")
19
 
20
  # Initialize Paddle's text detection model.
21
- # REMOVED the 'use_gpu=True' argument to fix the ValueError.
22
  det_model = TextDetection(model_name="PP-OCRv5_server_det")
23
 
24
  # Candidate language phrases for detection
@@ -27,7 +28,6 @@ candidates = [
27
  "This is Telugu text",
28
  "This is Chinese text",
29
  "This is Korean text",
30
- # Add other languages as needed
31
  ]
32
 
33
  # Map detected languages to PaddleOCR language codes
@@ -38,6 +38,8 @@ lang_map = {
38
  "korean": "korean",
39
  }
40
 
 
 
41
  # --- Utility Functions ---
42
  def get_box_center(box):
43
  """Calculates the center of a bounding box."""
@@ -47,45 +49,41 @@ def get_box_center(box):
47
  center_y = sum(y_coords) / len(y_coords)
48
  return center_x, center_y
49
 
50
- # --- Main OCR Pipeline Function ---
51
- @spaces.GPU # This decorator ensures the function is executed on the assigned GPU.
52
- def ocr_pipeline(image: Image.Image) -> str:
53
  """
54
  Performs OCR on an input image using a multi-step pipeline.
55
 
56
  Args:
57
- image: A PIL Image object from the Gradio interface.
58
 
59
  Returns:
60
  A string containing the reconstructed text.
61
  """
62
- if image is None:
63
  return "No image provided."
64
 
65
  print("Starting OCR pipeline...")
66
 
67
  # Convert PIL image to a NumPy array for OpenCV and Paddle
68
- img_np = np.array(image.convert("RGB"))
69
 
70
  # Step 1: Text Detection with PaddleOCR's model
71
- # This will be fast on the H200 GPU.
72
  output = det_model.predict(img_np, batch_size=1)
73
 
74
  arr = []
75
- for res in output:
76
- polys = res['dt_polys']
77
- if polys is not None:
78
- arr.extend(polys.tolist())
79
 
80
  # Sort the bounding boxes in reading order
81
- arr = sorted(arr, key=lambda box: (box[0][1], box[0][0]))
82
 
83
- if not arr:
84
  print("No text regions detected.")
85
  return "No text regions detected."
86
 
87
  cropped_images = []
88
- for box in arr:
89
  box = np.array(box, dtype=np.float32)
90
  width_a = np.linalg.norm(box[0] - box[1])
91
  width_b = np.linalg.norm(box[2] - box[3])
@@ -110,9 +108,7 @@ def ocr_pipeline(image: Image.Image) -> str:
110
  pil_img = Image.fromarray(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
111
 
112
  # Use CLIP to detect language. The model is already on the GPU.
113
- inputs = processor(text=candidates, images=pil_img, return_tensors="pt", padding=True)
114
- # Move inputs to the GPU
115
- inputs = {k: v.to(clip_model.device) for k, v in inputs.items()}
116
  with torch.no_grad():
117
  outputs = clip_model(**inputs)
118
  logits_per_image = outputs.logits_per_image
@@ -124,17 +120,16 @@ def ocr_pipeline(image: Image.Image) -> str:
124
  lang_code = lang_map.get(detected_lang, "en")
125
 
126
  # Initialize PaddleOCR with the detected language.
127
- # REMOVED the 'use_gpu=True' argument here as well.
128
- ocr = PaddleOCR(lang=lang_code, use_angle_cls=False, use_doc_unwarping=False)
129
  result = ocr.predict(img)
130
 
131
  # Extract text from OCR result
132
  text_for_this_image = ""
133
- if result and result[0] and result[0].get('rec_texts'):
134
  text_for_this_image = " ".join(result[0]['rec_texts'])
135
 
136
  # Store text and bounding box information
137
- center_x, center_y = get_box_center(arr[i])
138
  all_text_blocks.append({
139
  "text": text_for_this_image,
140
  "center_x": center_x,
 
1
  import gradio as gr
2
+ import torch
3
  from transformers import CLIPProcessor, CLIPModel
4
  from paddleocr import PaddleOCR, TextDetection
5
  from PIL import Image
 
6
  import numpy as np
7
  import cv2
 
8
  import spaces
9
 
10
  # --- Global setup for models and data ---
11
+ print("🔄 Initializing models...")
 
12
 
13
+ # Check for GPU and set device
14
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
15
+ print(f"Device being used: {device}")
16
+
17
+ # Load CLIP model once. This is memory-intensive, so we do it once.
18
+ clip_model = CLIPModel.from_pretrained("openai/clip-vit-large-patch14").to(device)
19
  processor = CLIPProcessor.from_pretrained("openai/clip-vit-large-patch14")
20
 
21
  # Initialize Paddle's text detection model.
22
+ # The latest versions of PaddlePaddle/PaddleOCR automatically use the GPU.
23
  det_model = TextDetection(model_name="PP-OCRv5_server_det")
24
 
25
  # Candidate language phrases for detection
 
28
  "This is Telugu text",
29
  "This is Chinese text",
30
  "This is Korean text",
 
31
  ]
32
 
33
  # Map detected languages to PaddleOCR language codes
 
38
  "korean": "korean",
39
  }
40
 
41
+ print("✅ Models loaded successfully.")
42
+
43
  # --- Utility Functions ---
44
  def get_box_center(box):
45
  """Calculates the center of a bounding box."""
 
49
  center_y = sum(y_coords) / len(y_coords)
50
  return center_x, center_y
51
 
52
+ @spaces.GPU
53
+ def ocr_pipeline(image_pil: Image.Image) -> str:
 
54
  """
55
  Performs OCR on an input image using a multi-step pipeline.
56
 
57
  Args:
58
+ image_pil: A PIL Image object from the Gradio interface.
59
 
60
  Returns:
61
  A string containing the reconstructed text.
62
  """
63
+ if image_pil is None:
64
  return "No image provided."
65
 
66
  print("Starting OCR pipeline...")
67
 
68
  # Convert PIL image to a NumPy array for OpenCV and Paddle
69
+ img_np = np.array(image_pil.convert("RGB"))
70
 
71
  # Step 1: Text Detection with PaddleOCR's model
 
72
  output = det_model.predict(img_np, batch_size=1)
73
 
74
  arr = []
75
+ if output and output[0] and 'dt_polys' in output[0] and output[0]['dt_polys'] is not None:
76
+ arr.extend(output[0]['dt_polys'].tolist())
 
 
77
 
78
  # Sort the bounding boxes in reading order
79
+ sorted_polys = sorted(arr, key=lambda box: (box[0][1], box[0][0]))
80
 
81
+ if not sorted_polys:
82
  print("No text regions detected.")
83
  return "No text regions detected."
84
 
85
  cropped_images = []
86
+ for box in sorted_polys:
87
  box = np.array(box, dtype=np.float32)
88
  width_a = np.linalg.norm(box[0] - box[1])
89
  width_b = np.linalg.norm(box[2] - box[3])
 
108
  pil_img = Image.fromarray(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
109
 
110
  # Use CLIP to detect language. The model is already on the GPU.
111
+ inputs = processor(text=candidates, images=pil_img, return_tensors="pt", padding=True).to(device)
 
 
112
  with torch.no_grad():
113
  outputs = clip_model(**inputs)
114
  logits_per_image = outputs.logits_per_image
 
120
  lang_code = lang_map.get(detected_lang, "en")
121
 
122
  # Initialize PaddleOCR with the detected language.
123
+ ocr = PaddleOCR(lang=lang_code, use_angle_cls=False, use_doc_unwarping=False, use_gpu=True)
 
124
  result = ocr.predict(img)
125
 
126
  # Extract text from OCR result
127
  text_for_this_image = ""
128
+ if result and result[0] and 'rec_texts' in result[0]:
129
  text_for_this_image = " ".join(result[0]['rec_texts'])
130
 
131
  # Store text and bounding box information
132
+ center_x, center_y = get_box_center(sorted_polys[i])
133
  all_text_blocks.append({
134
  "text": text_for_this_image,
135
  "center_x": center_x,