Winston de Jong commited on
Commit
1bf5f86
·
1 Parent(s): 0da216f

implement model

Browse files
Files changed (1) hide show
  1. app.py +5 -2
app.py CHANGED
@@ -44,7 +44,7 @@ def load_model(repo_id):
44
  # Initialize the ResNet-18 architecture
45
  model = torchvision.models.resnet18(pretrained=True) # TODO: does it matter if this is set to true or false?
46
  num_ftrs = model.fc.in_features
47
- model.fc = nn.Linear(num_ftrs, 100) # Adjust for your task (e.g., 128 classes)
48
  # TODO: check if this number^^ corresponds to the number of classes
49
 
50
  # Load the model weights
@@ -78,10 +78,13 @@ def process_image_str(groupImageFilePath: str):
78
  # do AI stuff here
79
  with torch.no_grad():
80
  outputs_t = model(intputTensor)
81
- _, pred_t = torch.max(outputs_t, dim=1)
 
 
82
  outputLabels.append(pred_t.item())
83
 
84
  #return gr.Image(image)
 
85
  return outputLabels.pop(0)
86
 
87
 
 
44
  # Initialize the ResNet-18 architecture
45
  model = torchvision.models.resnet18(pretrained=True) # TODO: does it matter if this is set to true or false?
46
  num_ftrs = model.fc.in_features
47
+ model.fc = nn.Linear(num_ftrs, 128) # Adjust for your task (e.g., 128 classes)
48
  # TODO: check if this number^^ corresponds to the number of classes
49
 
50
  # Load the model weights
 
78
  # do AI stuff here
79
  with torch.no_grad():
80
  outputs_t = model(intputTensor)
81
+ print(outputs_t)
82
+ temp, pred_t = torch.max(outputs_t, dim=1)
83
+ print(temp)
84
  outputLabels.append(pred_t.item())
85
 
86
  #return gr.Image(image)
87
+ print(outputLabels)
88
  return outputLabels.pop(0)
89
 
90