import torch
import torchvision
from torch import nn
from torchvision import transforms


def create_gadgets_model(
    num_classes: int = 3,
    seed: int = 42,
    hidden_units: int = 128,
) -> tuple[nn.Module, transforms.Compose]:
    """Build a ResNet-50 feature extractor with a small trainable classifier head.

    The backbone is loaded with torchvision's default pretrained ResNet-50
    weights and fully frozen. Its final ``fc`` layer is then replaced by a
    two-layer MLP head, which is the only trainable part of the model.

    Args:
        num_classes: Number of output logits produced by the new head.
        seed: Value passed to ``torch.manual_seed`` immediately before the
            head is built, so the head's initial weights are reproducible.
        hidden_units: Width of the hidden layer in the classifier head
            (defaults to 128, the previously hard-coded value).

    Returns:
        A ``(model, transforms)`` pair: the modified ResNet-50 and the
        preprocessing pipeline to apply to input images.
    """
    # Download/load the default pretrained ImageNet weights into ResNet-50.
    weights = torchvision.models.ResNet50_Weights.DEFAULT
    model = torchvision.models.resnet50(weights=weights)

    # Explicit preprocessing pipeline, used instead of weights.transforms().
    # The original author marked this as safe for their HF/Gradio deployment.
    # NOTE(review): Resize((224, 224)) stretches non-square images to a square
    # rather than resizing and center-cropping; confirm this matches the
    # preprocessing used during training.
    gadget_transforms = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),  # PIL image -> float tensor in [0, 1]; required before Normalize
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ])

    # Freeze every existing parameter. The head below is created afterwards,
    # so its freshly constructed parameters remain trainable (requires_grad=True).
    for param in model.parameters():
        param.requires_grad = False

    # Read the backbone's feature width instead of hard-coding 2048, so the
    # head stays correct if the backbone is ever swapped.
    in_features = model.fc.in_features

    # Seed right before constructing the head so its initialization is deterministic.
    torch.manual_seed(seed)
    model.fc = nn.Sequential(
        nn.Linear(in_features, hidden_units),
        nn.ReLU(inplace=True),
        nn.Linear(hidden_units, num_classes),
    )

    return model, gadget_transforms