"""Helpers for loading a ResNet-18 style classifier and running it on one image."""

import os
from typing import Optional, Sequence

import torch
import torch.nn as nn
import torchvision.transforms as transforms
from PIL import Image
from torch.utils.data import DataLoader
from torchvision import models
from torchvision.datasets import ImageFolder
from tqdm import tqdm

# Preprocessing applied to every image before inference; built once instead
# of on every call.
# NOTE(review): no mean/std normalization is applied here -- confirm this
# matches the preprocessing used when the checkpoint was trained.
_PREPROCESS = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
])


def load_model(
    model_path: str = "models/style_model.pth",
    class_names: Optional[Sequence[str]] = None,
) -> nn.Module:
    """Build a ResNet-18 classifier and load its weights from *model_path*.

    The final fully connected layer is replaced by a linear layer with one
    output per entry in *class_names*, the state dict stored at *model_path*
    is loaded onto the CPU, and the model is switched to evaluation mode.

    Args:
        model_path: Path to a checkpoint containing a state dict for the
            model built here.
        class_names: Class labels; only their count is used. ``None`` (the
            default) is treated as an empty sequence, which produces a
            zero-output head.

    Returns:
        The model in evaluation mode.
    """
    # ``None`` replaces the former mutable ``[]`` default; both mean "no classes".
    num_classes = len(class_names) if class_names is not None else 0

    model = models.resnet18(pretrained=False)
    model.fc = nn.Linear(model.fc.in_features, num_classes)

    # SECURITY: torch.load unpickles the file, so only load checkpoints from
    # trusted sources.
    state_dict = torch.load(model_path, map_location=torch.device('cpu'))
    model.load_state_dict(state_dict)
    model.eval()
    return model


def predict_style(image_path, model, class_names: Sequence[str]) -> str:
    """Return the entry of *class_names* that *model* predicts for an image.

    The image is opened, converted to RGB, resized to 224x224, turned into a
    tensor and passed through *model* as a batch of one with gradient
    tracking disabled.

    Args:
        image_path: Location of the image, passed to ``PIL.Image.open``.
        model: Callable that maps the image batch to a tensor of per-class
            scores along dimension 1.
        class_names: Labels indexed by the predicted class index.

    Returns:
        The label at the index of the highest score.
    """
    # The context manager closes the underlying file as soon as the pixel
    # data has been copied by ``convert``.
    with Image.open(image_path) as raw_image:
        rgb_image = raw_image.convert("RGB")

    # unsqueeze(0) adds the leading batch dimension.
    batch = _PREPROCESS(rgb_image).unsqueeze(0)

    with torch.no_grad():
        output = model(batch)
        _, predicted = torch.max(output, 1)

    return class_names[predicted.item()]