Vrajesharma committed on
Commit
3c3c15a
·
1 Parent(s): d2d676d

Live detection added

Browse files
Files changed (2) hide show
  1. app.py +169 -136
  2. requirements.txt +5 -5
app.py CHANGED
@@ -1,137 +1,170 @@
1
- import io
2
- import json
3
- import torch
4
- import torch.nn as nn
5
- from torchvision import models, transforms
6
- from PIL import Image
7
- from fastapi import FastAPI, File, UploadFile, HTTPException
8
- from fastapi.middleware.cors import CORSMiddleware
9
- from fastapi.responses import JSONResponse
10
-
11
- # ─── App setup ──────────────────────────────────────────────────────────────
12
- app = FastAPI(title="ISL Recognition API", version="1.0.0")
13
-
14
- app.add_middleware(
15
- CORSMiddleware,
16
- allow_origins=["*"], # Lock this to your Vercel URL in production
17
- allow_credentials=True,
18
- allow_methods=["*"],
19
- allow_headers=["*"],
20
- )
21
-
22
- # ─── Model loader ────────────────────────────────────────────────────────────
23
- def build_model(arch: str, num_classes: int) -> nn.Module:
24
- arch = arch.lower()
25
- if arch == "resnet18":
26
- model = models.resnet18(weights=None)
27
- model.fc = nn.Sequential(
28
- nn.Dropout(0.5),
29
- nn.Linear(model.fc.in_features, num_classes)
30
- )
31
- elif arch == "mobilenet_v2":
32
- model = models.mobilenet_v2(weights=None)
33
- model.classifier[1] = nn.Linear(model.classifier[1].in_features, num_classes)
34
- elif arch == "efficientnet_b0":
35
- model = models.efficientnet_b0(weights=None)
36
- model.classifier[1] = nn.Linear(model.classifier[1].in_features, num_classes)
37
- elif arch == "vgg16":
38
- model = models.vgg16(weights=None)
39
- model.classifier[6] = nn.Linear(model.classifier[6].in_features, num_classes)
40
- elif arch in ("cnn", "cnn_dropout"):
41
- # Simple custom CNN
42
- class _CNN(nn.Module):
43
- def __init__(self, n, dropout=False):
44
- super().__init__()
45
- self.features = nn.Sequential(
46
- nn.Conv2d(3, 32, 3, padding=1), nn.BatchNorm2d(32), nn.ReLU(True), nn.MaxPool2d(2),
47
- nn.Conv2d(32, 64, 3, padding=1), nn.BatchNorm2d(64), nn.ReLU(True), nn.MaxPool2d(2),
48
- nn.Conv2d(64, 128, 3, padding=1), nn.BatchNorm2d(128), nn.ReLU(True), nn.MaxPool2d(2),
49
- nn.Conv2d(128, 256, 3, padding=1), nn.BatchNorm2d(256), nn.ReLU(True), nn.MaxPool2d(2),
50
- )
51
- layers = [nn.AdaptiveAvgPool2d((1, 1)), nn.Flatten()]
52
- if dropout:
53
- layers.append(nn.Dropout(0.5))
54
- layers.append(nn.Linear(256, n))
55
- self.classifier = nn.Sequential(*layers)
56
- def forward(self, x):
57
- return self.classifier(self.features(x))
58
- model = _CNN(num_classes, dropout=(arch == "cnn_dropout"))
59
- else:
60
- raise ValueError(f"Unknown architecture: {arch}")
61
- return model
62
-
63
-
64
- # ─── Load checkpoint on startup ──────────────────────────────────────────────
65
- MODEL_PATH = "isl_best_model.pth"
66
- device = torch.device("cpu")
67
-
68
- checkpoint = torch.load(MODEL_PATH, map_location=device)
69
- ARCH = checkpoint["arch"]
70
- NUM_CLASSES = checkpoint["num_classes"]
71
- CLASS_NAMES = checkpoint["class_names"]
72
-
73
- model = build_model(ARCH, NUM_CLASSES)
74
- model.load_state_dict(checkpoint["state_dict"])
75
- model.eval()
76
- model.to(device)
77
-
78
- print(f"βœ… Loaded model: {ARCH} | Classes: {NUM_CLASSES} | Val Acc: {checkpoint.get('val_acc', 'N/A')}")
79
-
80
- # ─── Inference transform (matches val_transform in notebook) ─────────────────
81
- transform = transforms.Compose([
82
- transforms.Resize((224, 224)),
83
- transforms.ToTensor(),
84
- transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
85
- ])
86
-
87
-
88
- # ─── Routes ──────────────────────────────────────────────────────────────────
89
- @app.get("/")
90
- def root():
91
- return {
92
- "message": "ISL Recognition API is running 🀟",
93
- "model": ARCH,
94
- "num_classes": NUM_CLASSES,
95
- "val_acc": checkpoint.get("val_acc"),
96
- }
97
-
98
- @app.get("/health")
99
- def health():
100
- return {"status": "ok"}
101
-
102
- @app.get("/classes")
103
- def get_classes():
104
- return {"classes": CLASS_NAMES}
105
-
106
- @app.post("/predict")
107
- async def predict(file: UploadFile = File(...)):
108
- # Validate file type
109
- if file.content_type not in ("image/jpeg", "image/png", "image/jpg", "image/webp"):
110
- raise HTTPException(status_code=400, detail="Only JPEG/PNG images are accepted.")
111
-
112
- try:
113
- contents = await file.read()
114
- image = Image.open(io.BytesIO(contents)).convert("RGB")
115
- except Exception:
116
- raise HTTPException(status_code=400, detail="Could not read image file.")
117
-
118
- tensor = transform(image).unsqueeze(0).to(device) # [1, 3, 224, 224]
119
-
120
- with torch.no_grad():
121
- logits = model(tensor)
122
- probs = torch.softmax(logits, dim=1)[0]
123
-
124
- top5_probs, top5_idx = torch.topk(probs, k=min(5, NUM_CLASSES))
125
-
126
- return JSONResponse({
127
- "prediction": CLASS_NAMES[top5_idx[0].item()],
128
- "confidence": round(top5_probs[0].item() * 100, 2),
129
- "top5": [
130
- {
131
- "label": CLASS_NAMES[idx.item()],
132
- "confidence": round(prob.item() * 100, 2)
133
- }
134
- for prob, idx in zip(top5_probs, top5_idx)
135
- ],
136
- "model_used": ARCH,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
137
  })
 
1
+ import io
2
+ import json
3
+ import torch
4
+ import torch.nn as nn
5
+ from torchvision import models, transforms
6
+ from PIL import Image
7
+ from fastapi import FastAPI, File, UploadFile, HTTPException
8
+ from fastapi.middleware.cors import CORSMiddleware
9
+ from fastapi.responses import JSONResponse
10
+
11
+ # ─── App setup ──────────────────────────────────────────────────────────────
12
# ASGI application object; the title and version are passed to FastAPI as
# the API's metadata.
app = FastAPI(title="ISL Recognition API", version="1.0.0")

# CORS configuration for browser clients served from another origin.
#
# A wildcard origin must not be combined with allow_credentials=True: the
# CORS specification forbids answering a credentialed request with
# "Access-Control-Allow-Origin: *", and the FastAPI documentation requires
# explicit origins/methods/headers whenever credentials are enabled.
# None of the routes visible in this file read cookies or auth headers, so
# credentials are switched off while the origin list is still a wildcard.
# NOTE(review): if the front end ever sends credentialed requests, list its
# exact origin in allow_origins and re-enable allow_credentials then.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # Lock this to your Vercel URL in production
    allow_credentials=False,
    allow_methods=["*"],
    allow_headers=["*"],
)
21
+
22
+ # ─── Model loader ───────────────────────────────────────────────────────────
23
def build_model(arch: str, num_classes: int) -> nn.Module:
    """Construct an untrained network whose final layer has ``num_classes`` outputs.

    Args:
        arch: Architecture name, matched case-insensitively. Supported values
            are ``resnet18``, ``mobilenet_v2``, ``efficientnet_b0``, ``vgg16``,
            ``cnn`` and ``cnn_dropout``.
        num_classes: Output size of the final linear layer.

    Returns:
        The constructed module. torchvision backbones are created with
        ``weights=None``; no weights are loaded here.

    Raises:
        ValueError: If ``arch`` is not one of the supported names.
    """
    key = arch.lower()

    if key == "resnet18":
        net = models.resnet18(weights=None)
        in_dim = net.fc.in_features
        # Dropout in front of the new head, so state_dict keys are fc.1.*
        net.fc = nn.Sequential(nn.Dropout(0.5), nn.Linear(in_dim, num_classes))
        return net

    # torchvision backbones whose output layer is a single Linear stored in
    # `classifier`; the value is that layer's index inside `classifier`.
    head_slot = {"mobilenet_v2": 1, "efficientnet_b0": 1, "vgg16": 6}
    if key in head_slot:
        net = getattr(models, key)(weights=None)
        slot = head_slot[key]
        net.classifier[slot] = nn.Linear(net.classifier[slot].in_features, num_classes)
        return net

    if key not in ("cnn", "cnn_dropout"):
        raise ValueError(f"Unknown architecture: {key}")

    class _CNN(nn.Module):
        """Simple custom CNN: four conv stages, then a pooled linear head."""

        def __init__(self, n, dropout=False):
            super().__init__()
            stages = []
            # Each stage is Conv -> BatchNorm -> ReLU(inplace) -> MaxPool, kept
            # in one flat Sequential so parameter keys stay features.<index>.*
            for c_in, c_out in ((3, 32), (32, 64), (64, 128), (128, 256)):
                stages += [
                    nn.Conv2d(c_in, c_out, 3, padding=1),
                    nn.BatchNorm2d(c_out),
                    nn.ReLU(True),
                    nn.MaxPool2d(2),
                ]
            self.features = nn.Sequential(*stages)

            head = [nn.AdaptiveAvgPool2d((1, 1)), nn.Flatten()]
            if dropout:
                head.append(nn.Dropout(0.5))
            head.append(nn.Linear(256, n))
            self.classifier = nn.Sequential(*head)

        def forward(self, x):
            return self.classifier(self.features(x))

    return _CNN(num_classes, dropout=(key == "cnn_dropout"))
62
+
63
+
64
+ # ─── Load checkpoint on startup ─────────────────────────────────────────────
65
# Checkpoint file name; a relative path, so it is resolved against the
# process's current working directory.
MODEL_PATH = "isl_best_model.pth"
# All tensors and the model are kept on the CPU.
device = torch.device("cpu")

# NOTE(review): torch.load unpickles the file, and unpickling can execute
# arbitrary code. Only load checkpoints from a trusted source. Passing
# weights_only=True would be safer but needs confirming that every value
# stored in the checkpoint (e.g. val_acc) is a type that mode accepts.
checkpoint = torch.load(MODEL_PATH, map_location=device)
# Metadata read from the checkpoint dict; a missing key raises KeyError at
# import time, so the service fails fast on a malformed checkpoint.
ARCH = checkpoint["arch"]
NUM_CLASSES = checkpoint["num_classes"]
CLASS_NAMES = checkpoint["class_names"]

# Rebuild the architecture named in the checkpoint, restore its weights and
# switch to eval mode so Dropout/BatchNorm layers run in inference mode.
model = build_model(ARCH, NUM_CLASSES)
model.load_state_dict(checkpoint["state_dict"])
model.eval()
model.to(device)

# Startup log line; 'val_acc' is optional in the checkpoint, hence .get().
print(f"βœ… Loaded model: {ARCH} | Classes: {NUM_CLASSES} | Val Acc: {checkpoint.get('val_acc', 'N/A')}")

# ─── Inference transform (matches val_transform in notebook) ────────────────
# Resize to 224x224, convert to a float tensor, then normalise each channel
# with the given mean/std (these are the standard ImageNet statistics).
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])
86
+
87
+
88
+ # ─── Routes ─────────────────────────────────────────────────────────────────
89
@app.get("/")
def root():
    """Report that the service is running, plus basic facts about the loaded model."""
    info = {"message": "ISL Recognition API is running 🀟"}
    info["model"] = ARCH
    info["num_classes"] = NUM_CLASSES
    # 'val_acc' may be absent from the checkpoint; .get() yields None then.
    info["val_acc"] = checkpoint.get("val_acc")
    return info
97
+
98
@app.get("/health")
def health():
    """Health probe: always answers with the fixed payload ``{"status": "ok"}``."""
    payload = {"status": "ok"}
    return payload
101
+
102
@app.get("/classes")
def get_classes():
    """Return the class labels that were read from the checkpoint at startup."""
    return dict(classes=CLASS_NAMES)
105
+
106
@app.post("/predict")
async def predict(file: UploadFile = File(...)):
    """Classify one uploaded image and return the most likely labels.

    Args:
        file: The uploaded image. Its declared content type must be JPEG,
            PNG or WebP.

    Returns:
        A JSONResponse with ``prediction`` (best label), ``confidence``
        (percentage rounded to two decimals), ``top5`` (up to five
        ``{"label", "confidence"}`` entries, highest probability first) and
        ``model_used`` (architecture name from the checkpoint).

    Raises:
        HTTPException: 400 when the content type is not accepted or the
            uploaded bytes cannot be decoded as an image.
    """
    # The content type is declared by the client, so this is only a cheap
    # first filter; the decode below is what actually validates the upload.
    if file.content_type not in ("image/jpeg", "image/png", "image/jpg", "image/webp"):
        # WebP is in the accepted list, so the message names it as well.
        raise HTTPException(status_code=400, detail="Only JPEG, PNG or WebP images are accepted.")

    try:
        contents = await file.read()
        image = Image.open(io.BytesIO(contents)).convert("RGB")
    except Exception as exc:
        # Boundary handler: any read/decode failure becomes a client error,
        # chained so the original cause stays visible in server logs.
        raise HTTPException(status_code=400, detail="Could not read image file.") from exc

    tensor = transform(image).unsqueeze(0).to(device)  # [1, 3, 224, 224]

    with torch.no_grad():
        logits = model(tensor)
        probs = torch.softmax(logits, dim=1)[0]

    # topk returns values sorted in descending order, so entry 0 is the best.
    top5_probs, top5_idx = torch.topk(probs, k=min(5, NUM_CLASSES))

    # Convert both tensors to Python lists once instead of calling .item()
    # on every element.
    ranked = [
        {"label": CLASS_NAMES[idx], "confidence": round(prob * 100, 2)}
        for prob, idx in zip(top5_probs.tolist(), top5_idx.tolist())
    ]

    return JSONResponse({
        "prediction": ranked[0]["label"],
        "confidence": ranked[0]["confidence"],
        "top5": ranked,
        "model_used": ARCH,
    })
138
+
139
@app.post("/live")
async def live_predict(file: UploadFile = File(...)):
    """Classify a single uploaded frame; same contract as ``POST /predict``.

    The request validation, inference and response shape are exactly those
    of ``predict``. This route delegates to it instead of repeating the
    body, so the two endpoints cannot drift apart.

    Args:
        file: The uploaded image frame.

    Returns:
        The JSONResponse produced by ``predict`` for this upload.

    Raises:
        HTTPException: 400, propagated from ``predict``, for an unsupported
            content type or an undecodable image.
    """
    # FastAPI's route decorator returns the original coroutine function, so
    # `predict` can be awaited directly with the same UploadFile.
    return await predict(file)
requirements.txt CHANGED
@@ -1,6 +1,6 @@
1
- fastapi==0.115.5
2
- uvicorn[standard]==0.32.0
3
- python-multipart==0.0.17
4
- torch==2.4.0+cpu
5
- torchvision==0.19.0+cpu
6
  Pillow==10.4.0
 
1
+ fastapi==0.115.5
2
+ uvicorn[standard]==0.32.0
3
+ python-multipart==0.0.17
4
+ torch==2.4.0+cpu
5
+ torchvision==0.19.0+cpu
6
  Pillow==10.4.0