MedhaCodes committed on
Commit
06bbcb9
·
verified ·
1 Parent(s): 522c19f

Update app/infer.py

Browse files
Files changed (1) hide show
  1. app/infer.py +111 -40
app/infer.py CHANGED
@@ -1,59 +1,130 @@
1
- import cv2
2
- import torch
3
- import numpy as np
4
  import os
 
 
 
 
5
 
6
- from app.model import CRNN
7
- from app.utils import decode_with_confidence
8
-
9
- # ----------------------------
10
- # Device
11
- # ----------------------------
12
  DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
13
 
14
- # ----------------------------
15
- # Load model
16
- # ----------------------------
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
17
  BASE_DIR = os.path.dirname(os.path.abspath(__file__))
18
  WEIGHTS_PATH = os.path.join(BASE_DIR, "weights", "ocr_model.pth")
19
 
20
- model = CRNN()
 
 
 
21
  model.load_state_dict(torch.load(WEIGHTS_PATH, map_location=DEVICE))
22
- model.to(DEVICE)
23
  model.eval()
24
 
 
 
 
 
 
 
 
 
25
 
26
- # ----------------------------
27
- # Image preprocessing
28
- # ----------------------------
29
- def preprocess(image_bytes):
30
- img = cv2.imdecode(
31
- np.frombuffer(image_bytes, np.uint8),
32
- cv2.IMREAD_GRAYSCALE
33
- )
34
 
35
- if img is None:
36
- raise ValueError("Invalid image")
37
 
38
- img = cv2.resize(img, (160, 60))
39
- img = img.astype("float32") / 255.0
 
 
 
 
 
40
 
41
- tensor = torch.tensor(img).unsqueeze(0).unsqueeze(0)
42
- return tensor.to(DEVICE)
 
 
 
 
43
 
 
44
 
45
- # ----------------------------
46
- # Prediction entry point
47
- # ----------------------------
48
- def predict(image_bytes):
49
- x = preprocess(image_bytes)
50
 
51
- with torch.no_grad():
52
- logits = model(x)
 
53
 
54
- text, confidence = decode_with_confidence(logits)
 
 
 
 
 
 
 
55
 
56
- return {
57
- "text": text,
58
- "confidence": confidence
59
- }
 
1
+ # app/infer.py
 
 
2
  import os
3
+ import torch
4
+ import torch.nn as nn
5
+ from PIL import Image
6
+ import torchvision.transforms as T
7
 
8
# =====================
# DEVICE
# =====================
# Prefer the GPU when one is visible; the model and every input tensor are
# moved to this device before inference.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
12
 
13
# =====================
# CHARSET (EXACT MATCH)
# =====================
# The character table must line up index-for-index with the one used at
# training time, otherwise decoded class indices map to the wrong glyphs.
import string

DIGITS = string.digits
LOWER = string.ascii_lowercase
UPPER = string.ascii_uppercase

# CTC blank symbol, placed at the final index of the table.
BLANK_CHAR = "-"
CHARS = DIGITS + LOWER + UPPER + BLANK_CHAR

# Forward and reverse lookups between class index and character.
idx2char = dict(enumerate(CHARS))
char2idx = {ch: i for i, ch in idx2char.items()}
NUM_CLASSES = len(CHARS)
28
+
29
# =====================
# CRNN (EXACT SAME MODEL)
# =====================
class CRNN(nn.Module):
    """Convolutional recurrent network over single-channel images.

    NOTE(review): the attribute names (`cnn`, `rnn`, `fc`) and the exact
    ordering of layers inside the Sequential must not change — the saved
    state_dict keys in weights/ocr_model.pth depend on them.
    """

    def __init__(self):
        super().__init__()

        # Feature extractor. For the (1, 60, 160) inputs produced by the
        # transform below: height is pooled 60 -> 30 -> 15 -> 7, while width
        # is pooled only twice (160 -> 80 -> 40) thanks to the (2, 1) pool,
        # preserving horizontal resolution for the sequence axis.
        self.cnn = nn.Sequential(
            nn.Conv2d(1, 64, 3, padding=1), nn.BatchNorm2d(64), nn.ReLU(),
            nn.MaxPool2d(2, 2),

            nn.Conv2d(64, 128, 3, padding=1), nn.BatchNorm2d(128), nn.ReLU(),
            nn.MaxPool2d(2, 2),

            nn.Conv2d(128, 256, 3, padding=1), nn.BatchNorm2d(256), nn.ReLU(),
            nn.MaxPool2d((2, 1)),

            nn.Conv2d(256, 256, 3, padding=1), nn.ReLU()
        )

        # Each width position becomes one LSTM timestep carrying the
        # flattened 256 channels x 7 rows = 1792 features.
        self.rnn = nn.LSTM(
            input_size=256 * 7,
            hidden_size=256,
            num_layers=2,
            bidirectional=True,
            batch_first=True
        )

        # Bidirectional LSTM -> 2 * 256 features per timestep, projected
        # onto the character classes (incl. the CTC blank).
        self.fc = nn.Linear(512, NUM_CLASSES)

    def forward(self, x):
        # x: (batch, 1, H, W) -> CNN features (batch, 256, H/8 floored, W/4).
        x = self.cnn(x)
        b, c, h, w = x.shape
        # Move width to the time axis: (batch, w, c * h) for the LSTM.
        x = x.permute(0, 3, 1, 2).reshape(b, w, c * h)
        x, _ = self.rnn(x)
        # Per-timestep class logits: (batch, w, NUM_CLASSES).
        return self.fc(x)
65
+
66
# =====================
# LOAD MODEL
# =====================
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
WEIGHTS_PATH = os.path.join(BASE_DIR, "weights", "ocr_model.pth")

# Fail fast at import time with a clear message instead of a cryptic
# torch.load error later.
if not os.path.exists(WEIGHTS_PATH):
    raise FileNotFoundError(f"Missing model: {WEIGHTS_PATH}")

model = CRNN().to(DEVICE)
# SECURITY NOTE(review): torch.load unpickles the checkpoint — only ship
# trusted weight files; consider weights_only=True if the installed torch
# version supports it.
model.load_state_dict(torch.load(WEIGHTS_PATH, map_location=DEVICE))
model.eval()  # inference mode: BatchNorm uses running statistics
78
 
79
# =====================
# IMAGE TRANSFORM (60×160)
# =====================
# Preprocessing pipeline: single channel, resized to height 60 x width 160
# (torchvision Resize takes (H, W)), scaled to [0, 1] by ToTensor.
# NOTE(review): presumably matches the training-time preprocessing — confirm.
transform = T.Compose([
    T.Grayscale(),
    T.Resize((60, 160)),
    T.ToTensor()
])
87
 
88
# =====================
# BEAM SEARCH
# =====================
def ctc_beam_search(logits, beam_width=5):
    """Prefix beam decode over per-timestep class probabilities.

    Takes raw logits of shape (1, timesteps, classes) and returns up to
    `beam_width` candidates as (text, probability) pairs, best first.

    NOTE(review): duplicate prefixes are merged with max (not summed) and
    repeated characters are never collapsed — identical to the original
    implementation; confirm this matches how the model was trained.
    """
    probs = logits.softmax(2)
    steps, n_classes = probs.shape[1], probs.shape[2]

    # Each candidate is (decoded_text, path_probability); start empty.
    candidates = [("", 1.0)]

    for step in range(steps):
        scored = {}
        for text, prob in candidates:
            for cls in range(n_classes):
                p = probs[0, step, cls].item()
                # Prune vanishingly unlikely classes early.
                if p < 1e-4:
                    continue

                symbol = idx2char[cls]
                # A blank extends the candidate without emitting a character.
                extended = text + symbol if symbol != BLANK_CHAR else text
                updated = prob * p
                if updated > scored.get(extended, 0.0):
                    scored[extended] = updated

        # Keep only the highest-probability prefixes for the next step.
        ranked = sorted(scored.items(), key=lambda kv: kv[1], reverse=True)
        candidates = ranked[:beam_width]

    return candidates
 
 
 
 
115
 
116
def decode_with_confidence(logits):
    """Decode logits to (text, confidence) using the top beam-search result.

    Confidence is a heuristic: the raw path probability scaled by 10,
    clamped to 1.0, rounded to 3 decimals.
    """
    best_text, best_score = ctc_beam_search(logits)[0]
    confidence = round(min(1.0, best_score * 10), 3)
    return best_text, confidence
119
 
120
# =====================
# PUBLIC API
# =====================
def predict(pil_img: Image.Image):
    """Run OCR on a PIL image.

    Returns a dict: {"text": decoded string, "confidence": float in [0, 1]}.
    """
    # Preprocess to a (1, 1, 60, 160) batch on the model's device.
    batch = transform(pil_img).unsqueeze(0).to(DEVICE)

    # Inference only — skip autograd bookkeeping.
    with torch.no_grad():
        logits = model(batch)

    text, confidence = decode_with_confidence(logits)
    return {"text": text, "confidence": confidence}