| import os | |
| import torch | |
| import torchvision.transforms as transforms | |
| from torch.utils.data import DataLoader | |
| from torchvision.datasets import ImageFolder | |
| from torchvision import models | |
| import torch.nn as nn | |
| from tqdm import tqdm | |
def load_model(model_path="models/style_model.pth", class_names=None):
    """Build a ResNet-18 classifier and load saved weights onto the CPU.

    Args:
        model_path: Path to a file readable by ``torch.load`` that holds the
            model's ``state_dict``.
        class_names: Sequence of class labels. Its length sets the number of
            outputs of the replacement ``fc`` layer. When ``None`` or empty,
            the output size is taken from the checkpoint's ``fc.weight``
            tensor instead.

    Returns:
        The ResNet-18 model with the loaded weights, in eval mode.

    Raises:
        RuntimeError: from ``load_state_dict`` if the checkpoint does not
            match the constructed architecture (e.g. wrong class count).
    """
    import torch
    from torchvision import models

    # NOTE(review): torch.load unpickles the file; only load checkpoints from
    # trusted sources.
    state_dict = torch.load(model_path, map_location=torch.device("cpu"))

    if class_names is not None and len(class_names) > 0:
        num_classes = len(class_names)
    else:
        # No labels supplied: size the head from the saved classifier weights
        # (shape is (num_classes, in_features) for nn.Linear).
        fc_weight = (
            state_dict.get("fc.weight") if isinstance(state_dict, dict) else None
        )
        num_classes = fc_weight.shape[0] if fc_weight is not None else 0

    try:
        # `weights=None` is the replacement for the deprecated `pretrained=False`.
        model = models.resnet18(weights=None)
    except TypeError:  # torchvision < 0.13 has no `weights` argument
        model = models.resnet18(pretrained=False)
    model.fc = torch.nn.Linear(model.fc.in_features, num_classes)
    model.load_state_dict(state_dict)
    model.eval()
    return model
def predict_style(image_path, model, class_names):
    """Classify a single image and return its predicted class label.

    The image is converted to RGB, resized to 224x224 and turned into a
    tensor with ``ToTensor``. No ``Normalize`` step is applied, so this must
    match the preprocessing used at training time — TODO confirm against the
    training transforms.

    Args:
        image_path: Anything accepted by ``PIL.Image.open`` (path or file).
        model: Callable mapping a (1, 3, 224, 224) batch to scores whose
            dimension 1 indexes classes. This function does not call
            ``model.eval()``; the caller is responsible for that.
        class_names: Sequence indexed by the predicted class id.

    Returns:
        ``class_names[i]`` where ``i`` is the highest-scoring class.
    """
    from PIL import Image
    from torchvision import transforms
    import torch

    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
    ])

    # Context manager releases the file handle deterministically; Pillow
    # otherwise keeps it open for lazily-loaded / multi-frame images.
    with Image.open(image_path) as img:
        batch = transform(img.convert("RGB")).unsqueeze(0)  # add batch dim

    with torch.no_grad():
        logits = model(batch)
    predicted = logits.argmax(dim=1)
    return class_names[predicted.item()]