imperiusrex commited on
Commit
ab94877
·
verified ·
1 Parent(s): 700041b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +34 -34
app.py CHANGED
@@ -6,50 +6,48 @@ import numpy as np
6
  import cv2
7
  from paddleocr import TextDetection
8
 
9
- MODEL_HUB_ID = "imperiusrex/Handwritten_model" # <--- MAKE SURE THIS IS CORRECT
 
 
 
 
 
 
 
10
 
11
  processor = TrOCRProcessor.from_pretrained(MODEL_HUB_ID)
12
  model = VisionEncoderDecoderModel.from_pretrained(MODEL_HUB_ID)
13
-
14
- # Move model to appropriate device (GPU if available, else CPU)
15
- model.eval()
16
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
17
  model.to(device)
 
18
 
19
  ocr_det_model = TextDetection(model_name="PP-OCRv5_server_det")
20
 
21
- # --- Inference Function for Gradio ---
 
 
22
  def recognize_handwritten_text(image_input):
23
  if image_input is None:
24
  return "Please upload an image."
25
 
26
- # Convert Gradio image input (numpy array) to PIL Image
27
  image_pil = Image.fromarray(image_input).convert("RGB")
28
 
29
- # Perform text detection with PaddleOCR
30
- # PaddleOCR expects a file path or numpy array
31
  detection_results = ocr_det_model.predict(image_input, batch_size=1)
32
 
33
  detected_polys = []
34
  for res in detection_results:
35
- polys = res['dt_polys']
36
  if polys is not None:
37
  detected_polys.extend(polys.tolist())
38
 
39
  cropped_images = []
40
  if detected_polys:
41
- img_np = np.array(image_pil) # Convert PIL to NumPy for OpenCV
42
 
43
- for i, box in enumerate(detected_polys):
44
  box = np.array(box, dtype=np.float32)
45
 
46
- width_a = np.linalg.norm(box[0] - box[1])
47
- width_b = np.linalg.norm(box[2] - box[3])
48
- height_a = np.linalg.norm(box[0] - box[3])
49
- height_b = np.linalg.norm(box[1] - box[2])
50
-
51
- width = int(max(width_a, width_b))
52
- height = int(max(height_a, height_b))
53
 
54
  dst_rect = np.array([
55
  [0, 0],
@@ -60,9 +58,9 @@ def recognize_handwritten_text(image_input):
60
 
61
  M = cv2.getPerspectiveTransform(box, dst_rect)
62
  warped = cv2.warpPerspective(img_np, M, (width, height))
63
- cropped_images.append(Image.fromarray(warped).convert("RGB")) # Convert back to PIL
64
 
65
- cropped_images.reverse() # Apply reverse if that was intended based on your original code
66
 
67
  recognized_texts = []
68
  if cropped_images:
@@ -73,23 +71,25 @@ def recognize_handwritten_text(image_input):
73
  generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
74
  recognized_texts.append(generated_text)
75
  else:
76
- # Fallback if no text detected by PaddleOCR - process the whole image
77
  pixel_values = processor(images=image_pil, return_tensors="pt").pixel_values.to(device)
78
  with torch.no_grad():
79
  generated_ids = model.generate(pixel_values, max_new_tokens=64)
80
  generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
81
- recognized_texts.append("No specific text regions detected, processing full image: " + generated_text)
82
 
83
  return "\n".join(recognized_texts)
84
 
85
- # --- Gradio Interface Setup ---
86
- iface = gr.Interface(
87
- fn=recognize_handwritten_text,
88
- inputs=gr.Image(type="numpy", label="Upload Handwritten Image"),
89
- outputs="text",
90
- title="Handwritten Text Recognition with TrOCR and PaddleOCR",
91
- description="Upload an image with handwritten text to get it recognized. Uses PaddleOCR for text detection and TrOCR for recognition."
92
- )
93
-
94
-
95
- iface.launch()
 
 
 
 
6
  import cv2
7
  from paddleocr import TextDetection
8
 
9
+ # --- Constants ---
10
+ MODEL_HUB_ID = "imperiusrex/Handwritten_model"
11
+
12
+ # --- Device ---
13
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
14
+
15
+ # --- Load Models Globally ---
16
+ print("🔄 Loading models...")
17
 
18
  processor = TrOCRProcessor.from_pretrained(MODEL_HUB_ID)
19
  model = VisionEncoderDecoderModel.from_pretrained(MODEL_HUB_ID)
 
 
 
 
20
  model.to(device)
21
+ model.eval()
22
 
23
  ocr_det_model = TextDetection(model_name="PP-OCRv5_server_det")
24
 
25
+ print("✅ Models loaded successfully.")
26
+
27
+ # --- Inference Function ---
28
  def recognize_handwritten_text(image_input):
29
  if image_input is None:
30
  return "Please upload an image."
31
 
 
32
  image_pil = Image.fromarray(image_input).convert("RGB")
33
 
 
 
34
  detection_results = ocr_det_model.predict(image_input, batch_size=1)
35
 
36
  detected_polys = []
37
  for res in detection_results:
38
+ polys = res.get('dt_polys', [])
39
  if polys is not None:
40
  detected_polys.extend(polys.tolist())
41
 
42
  cropped_images = []
43
  if detected_polys:
44
+ img_np = np.array(image_pil)
45
 
46
+ for box in detected_polys:
47
  box = np.array(box, dtype=np.float32)
48
 
49
+ width = int(max(np.linalg.norm(box[0] - box[1]), np.linalg.norm(box[2] - box[3])))
50
+ height = int(max(np.linalg.norm(box[0] - box[3]), np.linalg.norm(box[1] - box[2])))
 
 
 
 
 
51
 
52
  dst_rect = np.array([
53
  [0, 0],
 
58
 
59
  M = cv2.getPerspectiveTransform(box, dst_rect)
60
  warped = cv2.warpPerspective(img_np, M, (width, height))
61
+ cropped_images.append(Image.fromarray(warped).convert("RGB"))
62
 
63
+ cropped_images.reverse()
64
 
65
  recognized_texts = []
66
  if cropped_images:
 
71
  generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
72
  recognized_texts.append(generated_text)
73
  else:
 
74
  pixel_values = processor(images=image_pil, return_tensors="pt").pixel_values.to(device)
75
  with torch.no_grad():
76
  generated_ids = model.generate(pixel_values, max_new_tokens=64)
77
  generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
78
+ recognized_texts.append("No text boxes detected. Full image OCR:\n" + generated_text)
79
 
80
  return "\n".join(recognized_texts)
81
 
82
+ # --- Gradio Interface ---
83
+ def build_interface():
84
+ return gr.Interface(
85
+ fn=recognize_handwritten_text,
86
+ inputs=gr.Image(type="numpy", label="Upload Handwritten Image"),
87
+ outputs="text",
88
+ title="✍️ Handwritten Text Recognition",
89
+ description="📷 Upload a handwritten image. Uses PaddleOCR (detection) + TrOCR (recognition).",
90
+ )
91
+
92
+ # --- Launch App ---
93
+ if __name__ == "__main__":
94
+ iface = build_interface()
95
+ iface.launch()