Ehtesham123 commited on
Commit
5b19d10
·
verified ·
1 Parent(s): 8fe1e58

Upload 2 files

Browse files
Files changed (2) hide show
  1. STD_detect.py +46 -0
  2. STR_recognize.py +18 -0
STD_detect.py ADDED
@@ -0,0 +1,46 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import cv2
2
+ import numpy as np
3
+ from ultralytics import YOLO
4
+ from PIL import Image
5
+
6
class OBBPredictor:
    """Detect oriented bounding boxes (OBB) with a YOLO model and return a
    perspective-rectified crop for each detected region."""

    def __init__(self, model_path):
        """Load the YOLO OBB model from *model_path*."""
        self.model = YOLO(model_path)

    @staticmethod
    def order_points(pts):
        """Order 4 corner points as [top-left, top-right, bottom-right, bottom-left].

        Uses the coordinate-sum / coordinate-difference heuristic:
        the top-left corner has the smallest x + y and the bottom-right the
        largest; the top-right has the smallest y - x and the bottom-left the
        largest.

        NOTE(review): this heuristic can misorder boxes rotated near 45°,
        where two corners share the same coordinate sum — acceptable for
        mostly-horizontal text regions, but confirm for your data.
        """
        rect = np.zeros((4, 2), dtype=np.float32)
        s = pts.sum(axis=1)
        rect[0] = pts[np.argmin(s)]    # top-left: smallest x + y
        rect[2] = pts[np.argmax(s)]    # bottom-right: largest x + y
        diff = np.diff(pts, axis=1)    # y - x for each point
        rect[1] = pts[np.argmin(diff)]  # top-right: smallest y - x
        rect[3] = pts[np.argmax(diff)]  # bottom-left: largest y - x
        return rect

    @staticmethod
    def crop_obb_region(image, points):
        """Perspective-warp the quadrilateral *points* out of *image*.

        The output size is taken from the longer of each pair of opposite
        edges of the ordered quad. Returns a PIL.Image.
        """
        ordered_pts = OBBPredictor.order_points(points).astype(np.float32)
        width = int(max(np.linalg.norm(ordered_pts[0] - ordered_pts[1]),
                        np.linalg.norm(ordered_pts[2] - ordered_pts[3])))
        height = int(max(np.linalg.norm(ordered_pts[1] - ordered_pts[2]),
                         np.linalg.norm(ordered_pts[3] - ordered_pts[0])))
        # Guard against degenerate (zero-area) detections: a 0-sized
        # destination would make cv2.warpPerspective fail.
        width = max(width, 1)
        height = max(height, 1)
        dst_pts = np.array([
            [0, 0], [width - 1, 0], [width - 1, height - 1], [0, height - 1]
        ], dtype=np.float32)

        M = cv2.getPerspectiveTransform(ordered_pts, dst_pts)
        warped = cv2.warpPerspective(image, M, (width, height))
        return Image.fromarray(warped)

    def predict(self, image_pil):
        """Run OBB detection on a PIL image.

        Returns a list of PIL.Image crops, one per detected oriented box
        (empty list when nothing is detected).
        """
        image_np = np.array(image_pil)
        results = self.model(image_np)
        crops = []
        for result in results:
            # result.obb.xyxyxyxy holds the 4 corner points of each box.
            if hasattr(result.obb, "xyxyxyxy") and len(result.obb.xyxyxyxy) > 0:
                for box in result.obb.xyxyxyxy:
                    points = box.cpu().numpy()
                    cropped = self.crop_obb_region(image_np, points)
                    crops.append(cropped)
        return crops
STR_recognize.py ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from PIL import Image
3
+ from strhub.data.module import SceneTextDataModule
4
+ from strhub.models.utils import load_from_checkpoint
5
+
6
class TextRecognizer:
    """Scene-text recognition wrapper around a strhub checkpoint.

    Loads the model once at construction time and exposes a single
    `recognize` call that maps a PIL image to its decoded text string.
    """

    def __init__(self, ckpt_path, device='cpu'):
        """Load the checkpoint onto *device* in eval mode and build the
        image transform matching the model's expected input size."""
        self.device = device
        self.str = load_from_checkpoint(ckpt_path).eval().to(device)
        self.img_transform = SceneTextDataModule.get_transform(self.str.hparams.img_size)

    def recognize(self, image_pil):
        """Return the recognized text for a single PIL image."""
        # Transform to a tensor and add a batch dimension of 1.
        batch = self.img_transform(image_pil).unsqueeze(0).to(self.device)
        with torch.no_grad():
            probs = self.str(batch).softmax(-1)
        # The tokenizer decodes per-sample; unwrap the single-item batch.
        labels, _ = self.str.tokenizer.decode(probs)
        return labels[0]