Angstormy commited on
Commit
7b2d644
·
verified ·
1 Parent(s): f71d7e4

Upload 5 Files

Browse files
Files changed (5) hide show
  1. Dockerfile +19 -0
  2. api.py +197 -0
  3. best_model_20k.pt +3 -0
  4. requirements.txt +10 -0
  5. vocab.json +103 -0
Dockerfile ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ FROM python:3.10-slim
2
+
3
+ WORKDIR /app
4
+
5
+ # Install minimal system dependencies
6
+ RUN apt-get update && apt-get install -y libglib2.0-0 && rm -rf /var/lib/apt/lists/*
7
+
8
+ # Install Python requirements
9
+ COPY requirements.txt .
10
+ RUN pip install --no-cache-dir -r requirements.txt
11
+
12
+ # Copy all the model files and api.py
13
+ COPY . .
14
+
15
+ # Expose port 7860 (Hugging Face Spaces default port)
16
+ EXPOSE 7860
17
+
18
+ # Start the FastAPI server using uvicorn
19
+ CMD ["uvicorn", "api:app", "--host", "0.0.0.0", "--port", "7860"]
api.py ADDED
@@ -0,0 +1,197 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fastapi import FastAPI, File, UploadFile, HTTPException
2
+ from fastapi.middleware.cors import CORSMiddleware
3
+ import torch
4
+ import torch.nn as nn
5
+ import torchvision.models as models
6
+ import torchvision.transforms as transforms
7
+ from PIL import Image, ImageOps
8
+ import json
9
+ import io
10
+ import os
11
+ import cv2
12
+ import numpy as np
13
+ import base64
14
+ import math
15
+
16
+ app = FastAPI()
17
+
18
+ # Allow CORS for React development
19
+ app.add_middleware(
20
+ CORSMiddleware,
21
+ allow_origins=["*"],
22
+ allow_credentials=True,
23
+ allow_methods=["*"],
24
+ allow_headers=["*"],
25
+ )
26
+
27
+ # --- Model Architecture (LITERAL SYNC from Training3.ipynb) ---
28
+
29
+ class PositionalEncoding1D(nn.Module):
30
+ def __init__(self, d_model, max_len=512):
31
+ super().__init__()
32
+ pe = torch.zeros(max_len, d_model)
33
+ position = torch.arange(0, max_len, dtype=torch.float32).unsqueeze(1)
34
+ div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
35
+ pe[:, 0::2] = torch.sin(position * div_term)
36
+ pe[:, 1::2] = torch.cos(position * div_term)
37
+ self.register_buffer('pe', pe.unsqueeze(0))
38
+
39
+ def forward(self, x):
40
+ return x + self.pe[:, :x.size(1)]
41
+
42
+ class OCRModel(nn.Module):
43
+ def __init__(self, vocab_size):
44
+ super().__init__()
45
+ # EXACT LAYER NAMES FROM CHECKPOINT
46
+ resnet = models.resnet34(weights=None)
47
+ self.encoder = nn.Sequential(*list(resnet.children())[:-2])
48
+ self.enc_proj = nn.Conv2d(512, 256, 1)
49
+
50
+ self.token_embed = nn.Embedding(vocab_size, 256)
51
+ self.pos_decoder = PositionalEncoding1D(256)
52
+
53
+ decoder_layer = nn.TransformerDecoderLayer(d_model=256, nhead=4, dim_feedforward=1024, batch_first=True)
54
+ self.decoder = nn.TransformerDecoder(decoder_layer, num_layers=4)
55
+
56
+ self.output_layer = nn.Linear(256, vocab_size)
57
+
58
+ def forward(self, images, tgt):
59
+ feat = self.encoder(images)
60
+ feat = self.enc_proj(feat)
61
+ memory = feat.flatten(2).permute(0, 2, 1)
62
+
63
+ tgt = self.token_embed(tgt)
64
+ tgt = self.pos_decoder(tgt)
65
+
66
+ mask = torch.triu(torch.ones(tgt.size(1), tgt.size(1), device=tgt.device), 1).bool()
67
+ out = self.decoder(tgt, memory, tgt_mask=mask)
68
+ return self.output_layer(out)
69
+
70
+ def fuzzy_correct(text):
71
+ """Pass-through: Identity logic for literal model verification."""
72
+ return text
73
+
74
+ # --- Global Resources ---
75
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
76
+ model = None
77
+ stoi = None
78
+ itos = None
79
+
80
+ def load_resources():
81
+ global model, stoi, itos
82
+ model_path = "best_model_20k.pt"
83
+ vocab_path = "vocab.json"
84
+
85
+ if os.path.exists(model_path):
86
+ checkpoint = torch.load(model_path, map_location=device)
87
+ if isinstance(checkpoint, dict) and "model_state_dict" in checkpoint:
88
+ state_dict = checkpoint["model_state_dict"]
89
+ stoi = checkpoint["stoi"]
90
+ itos = {int(k): v for k, v in checkpoint["itos"].items()}
91
+ else:
92
+ state_dict = checkpoint
93
+ if os.path.exists(vocab_path):
94
+ with open(vocab_path, "r", encoding="utf-8") as f:
95
+ vdata = json.load(f)
96
+ stoi = vdata["stoi"]
97
+ itos = {int(k): v for k, v in vdata["itos"].items()}
98
+
99
+ vocab_size = len(stoi)
100
+ model = OCRModel(vocab_size).to(device)
101
+ # STRICT=TRUE IS NOW ENABLED
102
+ model.load_state_dict(state_dict, strict=True)
103
+ model.eval()
104
+ print(f"✅ Checkpoint mapped perfectly to brain memory ({vocab_size} classes).")
105
+
106
+ def preprocess_image(image_bytes):
107
+ # 1. Load with OpenCV (for smart-focus extraction)
108
+ nparr = np.frombuffer(image_bytes, np.uint8)
109
+ img_gray = cv2.imdecode(nparr, cv2.IMREAD_GRAYSCALE)
110
+
111
+ # 2. Find the Word bounding box (Otsu Binarization)
112
+ _, thresh = cv2.threshold(img_gray, 0, 255, cv2.THRESH_BINARY_INV + cv2.THRESH_OTSU)
113
+ coords = cv2.findNonZero(thresh)
114
+
115
+ if coords is not None:
116
+ x, y, w, h = cv2.boundingRect(coords)
117
+ # 5% padding gives the best raw convolution feature mapping
118
+ pad_x, pad_y = int(w * 0.05), int(h * 0.05)
119
+ y_max, x_max = img_gray.shape
120
+ x1, y1 = max(0, x - pad_x), max(0, y - pad_y)
121
+ x2, y2 = min(x_max, x + w + pad_x), min(y_max, y + h + pad_y)
122
+ img_cropped = img_gray[y1:y2, x1:x2]
123
+ else:
124
+ img_cropped = img_gray
125
+
126
+ # 3. LITERAL NOTEBOOK TRANSFORM (Applied on Focused Word)
127
+ pil_img = Image.fromarray(img_cropped).convert("L")
128
+
129
+ # Resize exactly as notebook (IMG_HEIGHT=48, MAX_WIDTH=160)
130
+ pil_img = pil_img.resize((160, 48), Image.BILINEAR)
131
+
132
+ # 3-Channel Grayscale (Exact Notebook Sync)
133
+ pil_img = transforms.Compose([
134
+ transforms.Grayscale(num_output_channels=3),
135
+ transforms.ToTensor()
136
+ ])(pil_img)
137
+
138
+ img_tensor = pil_img.unsqueeze(0).to(device)
139
+
140
+ # UI View (Debug log)
141
+ debug_arr = (img_tensor.squeeze(0).permute(1, 2, 0).cpu().numpy() * 255).astype(np.uint8)
142
+ debug_arr = cv2.cvtColor(debug_arr, cv2.COLOR_RGB2BGR)
143
+ _, buffer = cv2.imencode('.png', debug_arr)
144
+ debug_b64 = base64.b64encode(buffer).decode('utf-8')
145
+
146
+ return img_tensor, debug_b64
147
+
148
+ @app.on_event("startup")
149
+ async def startup_event():
150
+ load_resources()
151
+
152
+ def greedy_decode(model, images, max_len=25):
153
+ """Refined Greedy Decode with special token verification."""
154
+ B = images.size(0)
155
+ BOS_VAL = stoi.get("<bos>", 1)
156
+ EOS_VAL = stoi.get("<eos>", 2)
157
+ PAD_VAL = stoi.get("<pad>", 0)
158
+
159
+ decoded = torch.full((B, 1), BOS_VAL, dtype=torch.long, device=device)
160
+
161
+ for _ in range(max_len):
162
+ with torch.cuda.amp.autocast():
163
+ logits = model(images, decoded)
164
+ next_token = logits[:, -1, :].argmax(dim=-1, keepdim=True)
165
+ decoded = torch.cat([decoded, next_token], dim=1)
166
+ if next_token.item() == EOS_VAL: break
167
+
168
+ ids = decoded[0].tolist()
169
+ # Decode string exactly as notebook does
170
+ out = []
171
+ for i in ids:
172
+ if i == EOS_VAL: break
173
+ if i in [PAD_VAL, BOS_VAL]: continue
174
+ out.append(itos.get(i, ""))
175
+ return "".join(out)
176
+
177
+ @app.post("/predict")
178
+ async def predict_ocr(file: UploadFile = File(...)):
179
+ if model is None: return {"error": "Model not loaded"}
180
+ try:
181
+ image_bytes = await file.read()
182
+ images, debug_b64 = preprocess_image(image_bytes)
183
+ prediction = greedy_decode(model, images)
184
+ final_prediction = fuzzy_correct(prediction)
185
+ print(f"RECOGNIZED: '{prediction}' -> FINAL: '{final_prediction}'")
186
+ return {
187
+ "prediction": final_prediction,
188
+ "engine_view": f"data:image/png;base64,{debug_b64}"
189
+ }
190
+ except Exception as e:
191
+ import traceback
192
+ traceback.print_exc()
193
+ return {"error": str(e)}
194
+
195
+ if __name__ == "__main__":
196
+ import uvicorn
197
+ uvicorn.run(app, host="0.0.0.0", port=8000)
best_model_20k.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:3f808a21186c88003997e034616c6c8310ca9bbd5456830814b17c86c380d75b
3
+ size 309043646
requirements.txt ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ fastapi
2
+ uvicorn
3
+ python-multipart
4
+ --extra-index-url https://download.pytorch.org/whl/cpu
5
+ torch
6
+ torchvision
7
+ opencv-python-headless
8
+ numpy
9
+ pillow
10
+ pydantic
vocab.json ADDED
@@ -0,0 +1,103 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ [
2
+ "<pad>",
3
+ "<bos>",
4
+ "<eos>",
5
+ "<unk>",
6
+ "-",
7
+ "ँ",
8
+ "ं",
9
+ "ः",
10
+ "अ",
11
+ "आ",
12
+ "इ",
13
+ "ई",
14
+ "उ",
15
+ "ऊ",
16
+ "ऋ",
17
+ "ऌ",
18
+ "ऍ",
19
+ "ऎ",
20
+ "ए",
21
+ "ऐ",
22
+ "ऑ",
23
+ "ओ",
24
+ "औ",
25
+ "क",
26
+ "ख",
27
+ "ग",
28
+ "घ",
29
+ "ङ",
30
+ "च",
31
+ "छ",
32
+ "ज",
33
+ "झ",
34
+ "ञ",
35
+ "ट",
36
+ "ठ",
37
+ "ड",
38
+ "ढ",
39
+ "ण",
40
+ "त",
41
+ "थ",
42
+ "द",
43
+ "ध",
44
+ "न",
45
+ "ऩ",
46
+ "प",
47
+ "फ",
48
+ "ब",
49
+ "भ",
50
+ "म",
51
+ "य",
52
+ "र",
53
+ "ऱ",
54
+ "ल",
55
+ "ऴ",
56
+ "व",
57
+ "श",
58
+ "ष",
59
+ "स",
60
+ "ह",
61
+ "़",
62
+ "ऽ",
63
+ "ा",
64
+ "ि",
65
+ "ी",
66
+ "ु",
67
+ "ू",
68
+ "ृ",
69
+ "ॄ",
70
+ "ॅ",
71
+ "े",
72
+ "ै",
73
+ "ॉ",
74
+ "ॊ",
75
+ "ो",
76
+ "ौ",
77
+ "्",
78
+ "ॐ",
79
+ "॑",
80
+ "॒",
81
+ "॓",
82
+ "ॠ",
83
+ "ॢ",
84
+ "।",
85
+ "॥",
86
+ "०",
87
+ "१",
88
+ "२",
89
+ "३",
90
+ "४",
91
+ "५",
92
+ "६",
93
+ "७",
94
+ "८",
95
+ "९",
96
+ "॰",
97
+ "ॱ",
98
+ "ॲ",
99
+ "ॻ",
100
+ "ॼ",
101
+ "ॽ",
102
+ "ॾ"
103
+ ]