| import torch |
| import torchvision.transforms as transforms |
| from PIL import Image |
| from huggingface_hub import hf_hub_download |
| from model import NetFeat, NetClassifier |
|
|
| |
# Label names for the 14 Clothing1M categories. The position in this list is
# the class index: run_inference maps the classifier's argmax straight onto it.
CLOTHING_CLASSES = [
    "T-shirt",
    "Shirt",
    "Shawl",
    "Dress",
    "Vest",
    "Underwear",
    "Cardigan",
    "Jacket",
    "Sweater",
    "Hoodie",
    "Knitwear",
    "Chiffon",
    "Downcoat",
    "Suit",
]
|
|
| |
def load_model(model_filename="CC_net.pt", *, map_location="cpu"):
    """Build the Clothing1M feature extractor + classifier and load their weights.

    Args:
        model_filename: Path to a checkpoint readable by ``torch.load``. It is
            expected to be a dict whose ``"feat"`` and ``"cls"`` entries hold
            the state dicts of the feature network and the classifier.
        map_location: Device the checkpoint tensors are mapped onto
            (passed straight to ``torch.load``).

    Returns:
        ``(net_feat, net_cls)``, both switched to evaluation mode.

    Warns:
        UserWarning: when a sub-network's entry is absent from the checkpoint,
            or its parameter names only partially match. Loading uses
            ``strict=False``, so without these warnings the model would
            silently run with (partially) untrained weights.
    """
    import warnings

    net_feat = NetFeat(arch='resnet18', pretrained=False, dataset='Clothing1M')
    # One output per entry of CLOTHING_CLASSES so argmax indices map to names.
    net_cls = NetClassifier(feat_dim=512, nb_cls=len(CLOTHING_CLASSES))

    # NOTE(review): torch.load unpickles the file -- only load trusted checkpoints.
    checkpoint = torch.load(model_filename, map_location=map_location)

    for key, net in (("feat", net_feat), ("cls", net_cls)):
        if key not in checkpoint:
            warnings.warn(
                f"checkpoint {model_filename!r} has no {key!r} entry; "
                f"{type(net).__name__} keeps its untrained weights"
            )
            continue
        # strict=False tolerates key mismatches; surface them instead of hiding them.
        result = net.load_state_dict(checkpoint[key], strict=False)
        if result.missing_keys or result.unexpected_keys:
            warnings.warn(
                f"partial load of {key!r} from {model_filename!r}: "
                f"missing={list(result.missing_keys)}, "
                f"unexpected={list(result.unexpected_keys)}"
            )

    net_feat.eval()
    net_cls.eval()
    return net_feat, net_cls
|
|
| |
def preprocess_image(image):
    """Load an image and turn it into a normalized single-image batch tensor.

    Args:
        image: A path or binary file object accepted by ``PIL.Image.open``,
            or an already-loaded ``PIL.Image.Image``.

    Returns:
        A float tensor of shape ``(1, 3, 224, 224)``: the image converted to RGB,
        resized to 224x224, scaled by ``ToTensor`` and normalized per channel
        with the given mean/std (the usual ImageNet statistics).
    """
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])
    if isinstance(image, Image.Image):
        rgb = image.convert("RGB")
    else:
        # convert() loads the pixels into a new image, so the opened file can be
        # closed right away instead of leaking until garbage collection.
        with Image.open(image) as opened:
            rgb = opened.convert("RGB")
    return transform(rgb).unsqueeze(0)
|
|
| |
def run_inference(image_path, net_feat, net_cls):
    """Predict the clothing category of one image.

    Args:
        image_path: Anything ``preprocess_image`` accepts (e.g. a file path).
        net_feat: Feature network; its output is fed directly to ``net_cls``.
        net_cls: Classifier producing one score per class. Presumably shaped
            ``(1, len(CLOTHING_CLASSES))`` -- TODO confirm against NetClassifier.

    Returns:
        The entry of ``CLOTHING_CLASSES`` at the highest-scoring index.
    """
    batch = preprocess_image(image_path)

    # Gradients are never needed for prediction.
    with torch.no_grad():
        scores = net_cls(net_feat(batch))

    best_index = scores.argmax(dim=1).item()
    return CLOTHING_CLASSES[best_index]
|
|