PrachiY commited on
Commit
738116d
·
verified ·
1 Parent(s): 0f2b2af

Upload app.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. app.py +5 -2
app.py CHANGED
@@ -13,8 +13,9 @@ model_path = hf_hub_download(
13
 
14
  # ✅ Load the Model
15
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
16
- model = models.resnet50(pretrained=False)
17
 
 
18
  checkpoint = torch.load(model_path, map_location=device)
19
 
20
  if "model" in checkpoint:
@@ -24,6 +25,7 @@ elif "state_dict" in checkpoint:
24
  else:
25
  model.load_state_dict(checkpoint, strict=False)
26
 
 
27
  model.fc = torch.nn.Linear(2048, 21)
28
  model.to(device)
29
  model.eval()
@@ -42,7 +44,8 @@ def preprocess_image(image):
42
  transforms.ToTensor(),
43
  transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
44
  ])
45
- return transform(image).unsqueeze(0).to(device)
 
46
 
47
  # ✅ Prediction Function
48
  def predict(image):
 
13
 
14
  # ✅ Load the Model
15
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
16
+ print(f"✅ Using device: {device}")
17
 
18
+ model = models.resnet50(pretrained=False)
19
  checkpoint = torch.load(model_path, map_location=device)
20
 
21
  if "model" in checkpoint:
 
25
  else:
26
  model.load_state_dict(checkpoint, strict=False)
27
 
28
+ # Ensure correct output layer (21 classes for Clothing1M)
29
  model.fc = torch.nn.Linear(2048, 21)
30
  model.to(device)
31
  model.eval()
 
44
  transforms.ToTensor(),
45
  transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
46
  ])
47
+ image = transform(image).unsqueeze(0).to(device) # Ensure tensor is on GPU if available
48
+ return image
49
 
50
  # ✅ Prediction Function
51
  def predict(image):