ishaal007 commited on
Commit
2703b7d
·
verified ·
1 Parent(s): 6a07f93

Update model.py

Browse files
Files changed (1) hide show
  1. model.py +21 -13
model.py CHANGED
@@ -1,25 +1,33 @@
1
  import torch
2
  import torchvision
3
-
4
  from torch import nn
 
5
 
6
-
7
- def create_gadgets_model(num_classes:int=3,
8
- seed:int=42):
9
- # Create EffNetB2 pretrained weights, transforms and model
10
  weights = torchvision.models.ResNet50_Weights.DEFAULT
11
- transforms = weights.transforms()
12
  model = torchvision.models.resnet50(weights=weights)
13
 
14
- # Freeze all layers in base model
 
 
 
 
 
 
 
 
 
 
15
  for param in model.parameters():
16
  param.requires_grad = False
17
 
18
- # Change classifier head with random seed for reproducibility
19
  torch.manual_seed(seed)
20
  model.fc = nn.Sequential(
21
- nn.Linear(2048, 128),
22
- nn.ReLU(inplace=True),
23
- nn.Linear(in_features= 128,out_features=num_classes))
24
-
25
- return model, transforms
 
 
1
  import torch
2
  import torchvision
 
3
  from torch import nn
4
+ from torchvision import transforms
5
 
6
+ def create_gadgets_model(num_classes: int = 3, seed: int = 42):
7
+ # Load pretrained model (weights only)
 
 
8
  weights = torchvision.models.ResNet50_Weights.DEFAULT
 
9
  model = torchvision.models.resnet50(weights=weights)
10
 
11
+ # SAFE manual transforms (HF + Gradio compatible)
12
+ gadget_transforms = transforms.Compose([
13
+ transforms.Resize((224, 224)),
14
+ transforms.ToTensor(), # VERY IMPORTANT
15
+ transforms.Normalize(
16
+ mean=[0.485, 0.456, 0.406],
17
+ std=[0.229, 0.224, 0.225]
18
+ )
19
+ ])
20
+
21
+ # Freeze base model
22
  for param in model.parameters():
23
  param.requires_grad = False
24
 
25
+ # Classifier head
26
  torch.manual_seed(seed)
27
  model.fc = nn.Sequential(
28
+ nn.Linear(2048, 128),
29
+ nn.ReLU(inplace=True),
30
+ nn.Linear(128, num_classes)
31
+ )
32
+
33
+ return model, gadget_transforms