Spaces:
Sleeping
Sleeping
| import torch | |
| import torchvision | |
| from torch import nn | |
| from torchvision import transforms | |
def create_gadgets_model(num_classes: int = 3, seed: int = 42):
    """Build a frozen, pretrained ResNet-50 with a fresh trainable classifier head.

    The backbone is loaded with torchvision's ``ResNet50_Weights.DEFAULT``
    weights and every one of its parameters is frozen. The original ``fc``
    layer is then replaced by a small two-layer MLP head, which is the only
    part of the returned model left trainable.

    Args:
        num_classes: Number of output logits produced by the new head.
        seed: Value passed to ``torch.manual_seed`` immediately before the
            head is created, so its initial weights are reproducible. Note
            that this reseeds torch's global RNG as a side effect.

    Returns:
        A ``(model, gadget_transforms)`` tuple: the modified ResNet-50 and the
        ``transforms.Compose`` preprocessing pipeline to apply to input images.
    """
    # Load the pretrained backbone.
    weights = torchvision.models.ResNet50_Weights.DEFAULT
    model = torchvision.models.resnet50(weights=weights)

    # Preprocessing is spelled out by hand instead of taken from the weights
    # object (kept explicit for HF Spaces + Gradio compatibility).
    gadget_transforms = transforms.Compose([
        transforms.Resize((224, 224)),
        # ToTensor must run before Normalize: Normalize operates on tensors,
        # not on PIL images.
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ])

    # Freeze the whole backbone. The head created below is built afterwards,
    # so its parameters keep the default requires_grad=True.
    model.requires_grad_(False)

    # Read the head's input width from the layer being replaced rather than
    # hard-coding 2048, so the head always matches the backbone's features.
    in_features = model.fc.in_features

    # Seed right before constructing the head so its initialization is
    # deterministic for a given seed.
    torch.manual_seed(seed)
    model.fc = nn.Sequential(
        nn.Linear(in_features, 128),
        nn.ReLU(inplace=True),
        nn.Linear(128, num_classes),
    )

    return model, gadget_transforms