zynt31 commited on
Commit
70b80f4
·
verified ·
1 Parent(s): 665d2e0

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +13 -4
app.py CHANGED
@@ -287,7 +287,17 @@ class RiceLeafValidator:
287
  from torchvision.models import resnet18
288
 
289
  model = resnet18(weights=None)
290
- model.fc = nn.Linear(model.fc.in_features, 2) # Binary classification
 
 
 
 
 
 
 
 
 
 
291
 
292
  state_dict = torch.load(model_path, map_location=self.device, weights_only=True)
293
  model.load_state_dict(state_dict)
@@ -613,6 +623,5 @@ async def predict_disease(file: UploadFile = File(...), use_ai_recommendation: b
613
 
614
  if __name__ == "__main__":
615
  import uvicorn
616
- # Hugging Face Spaces uses port 7860
617
- port = int(os.environ.get("PORT", 7860))
618
- uvicorn.run("app:app", host="0.0.0.0", port=port, reload=False)
 
287
  from torchvision.models import resnet18
288
 
289
  model = resnet18(weights=None)
290
+
291
+ # The validator was trained with a Sequential FC head:
292
+ # fc.0 = Linear(512, hidden), fc.1 = ReLU, fc.2 = Dropout, fc.3 = Linear(hidden, 2)
293
+ # We need to match this architecture exactly
294
+ in_features = model.fc.in_features # 512 for ResNet18
295
+ model.fc = nn.Sequential(
296
+ nn.Linear(in_features, 256),
297
+ nn.ReLU(),
298
+ nn.Dropout(0.5),
299
+ nn.Linear(256, 2) # Binary classification: rice / not rice
300
+ )
301
 
302
  state_dict = torch.load(model_path, map_location=self.device, weights_only=True)
303
  model.load_state_dict(state_dict)
 
623
 
624
  if __name__ == "__main__":
625
  import uvicorn
626
+ # Production settings: no hot reload, proper host binding
627
+ uvicorn.run("app:app", host="0.0.0.0", port=8000, reload=False)