MedhaCodes commited on
Commit
8b1b3a9
·
verified ·
1 Parent(s): 065e49f

Update app/infer.py

Browse files
Files changed (1) hide show
  1. app/infer.py +31 -63
app/infer.py CHANGED
@@ -4,31 +4,14 @@ import torch
4
  import torch.nn as nn
5
  from PIL import Image
6
  import torchvision.transforms as T
 
7
 
8
- # =====================
9
- # DEVICE
10
- # =====================
11
  DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
12
-
13
- # =====================
14
- # CHARSET (EXACT MATCH)
15
- # =====================
16
- import string
17
-
18
- DIGITS = string.digits
19
- LOWER = string.ascii_lowercase
20
- UPPER = string.ascii_uppercase
21
-
22
- BLANK_CHAR = "-"
23
- CHARS = DIGITS + LOWER + UPPER + BLANK_CHAR
24
-
25
- char2idx = {c: i for i, c in enumerate(CHARS)}
26
- idx2char = {i: c for c, i in char2idx.items()}
27
  NUM_CLASSES = len(CHARS)
28
 
29
- # =====================
30
- # CRNN (EXACT SAME MODEL)
31
- # =====================
32
  class CRNN(nn.Module):
33
  def __init__(self):
34
  super().__init__()
@@ -63,68 +46,53 @@ class CRNN(nn.Module):
63
  x, _ = self.rnn(x)
64
  return self.fc(x)
65
 
66
- # =====================
67
  # LOAD MODEL
68
- # =====================
69
  BASE_DIR = os.path.dirname(os.path.abspath(__file__))
70
- WEIGHTS_PATH = os.path.join(BASE_DIR, "weights", "ocr_model.pth")
71
-
72
- if not os.path.exists(WEIGHTS_PATH):
73
- raise FileNotFoundError(f"Missing model: {WEIGHTS_PATH}")
74
 
75
  model = CRNN().to(DEVICE)
76
- model.load_state_dict(torch.load(WEIGHTS_PATH, map_location=DEVICE))
77
  model.eval()
78
 
79
- # =====================
80
- # IMAGE TRANSFORM (60×160)
81
- # =====================
82
  transform = T.Compose([
83
  T.Grayscale(),
84
  T.Resize((60, 160)),
85
  T.ToTensor()
86
  ])
87
 
88
- # =====================
89
- # BEAM SEARCH
90
- # =====================
91
- def ctc_beam_search(logits, beam_width=5):
92
- probs = logits.softmax(2)
93
- T, C = probs.shape[1], probs.shape[2]
94
 
95
- beams = [("", 1.0)]
 
96
 
97
- for t in range(T):
98
- new_beams = {}
99
- for prefix, score in beams:
100
- for c in range(C):
101
- p = probs[0, t, c].item()
102
- if p < 1e-4:
103
- continue
104
 
105
- char = idx2char[c]
106
- new_prefix = prefix if char == BLANK_CHAR else prefix + char
107
- new_beams[new_prefix] = max(
108
- new_beams.get(new_prefix, 0.0),
109
- score * p
110
- )
111
 
112
- beams = sorted(new_beams.items(), key=lambda x: x[1], reverse=True)[:beam_width]
113
-
114
- return beams
115
-
116
- def decode_with_confidence(logits):
117
- text, score = ctc_beam_search(logits)[0]
118
- return text, round(min(1.0, score * 10), 3)
119
-
120
- # =====================
121
  # PUBLIC API
122
- # =====================
123
  def predict(pil_img: Image.Image):
124
  img = transform(pil_img).unsqueeze(0).to(DEVICE)
125
 
126
  with torch.no_grad():
127
  logits = model(img)
128
 
129
- text, conf = decode_with_confidence(logits)
130
- return {"text": text, "confidence": conf}
 
 
 
4
  import torch.nn as nn
5
  from PIL import Image
6
  import torchvision.transforms as T
7
+ from app.utils import CHARS, idx2char, BLANK_CHAR
8
 
 
 
 
9
  DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
10
  NUM_CLASSES = len(CHARS)
11
 
12
+ # --------------------
13
+ # CRNN MODEL (SAME AS TRAINING)
14
+ # --------------------
15
  class CRNN(nn.Module):
16
  def __init__(self):
17
  super().__init__()
 
46
  x, _ = self.rnn(x)
47
  return self.fc(x)
48
 
49
+ # --------------------
50
  # LOAD MODEL
51
+ # --------------------
52
  BASE_DIR = os.path.dirname(os.path.abspath(__file__))
53
+ WEIGHTS = os.path.join(BASE_DIR, "weights", "ocr_model.pth")
 
 
 
54
 
55
  model = CRNN().to(DEVICE)
56
+ model.load_state_dict(torch.load(WEIGHTS, map_location=DEVICE))
57
  model.eval()
58
 
59
+ # --------------------
60
+ # TRANSFORM
61
+ # --------------------
62
  transform = T.Compose([
63
  T.Grayscale(),
64
  T.Resize((60, 160)),
65
  T.ToTensor()
66
  ])
67
 
68
+ # --------------------
69
+ # CTC DECODER
70
+ # --------------------
71
+ def ctc_decode(logits):
72
+ probs = logits.softmax(2)[0]
73
+ best = probs.argmax(1)
74
 
75
+ prev = None
76
+ text = ""
77
 
78
+ for idx in best:
79
+ idx = idx.item()
80
+ if idx != prev and CHARS[idx] != BLANK_CHAR:
81
+ text += CHARS[idx]
82
+ prev = idx
 
 
83
 
84
+ return text
 
 
 
 
 
85
 
86
+ # --------------------
 
 
 
 
 
 
 
 
87
  # PUBLIC API
88
+ # --------------------
89
  def predict(pil_img: Image.Image):
90
  img = transform(pil_img).unsqueeze(0).to(DEVICE)
91
 
92
  with torch.no_grad():
93
  logits = model(img)
94
 
95
+ text = ctc_decode(logits)
96
+ confidence = round(float(logits.softmax(2).max()), 3)
97
+
98
+ return text, confidence