# Source: curato / src/style_classifier.py
# Author: Sanj12 — "Upload 14 files" (commit 5e90518, verified)
import os
import torch
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torchvision.datasets import ImageFolder
from torchvision import models
import torch.nn as nn
from tqdm import tqdm
def load_model(model_path="models/style_model.pth", class_names=None):
    """Build a ResNet-18 style classifier and load its weights from *model_path*.

    The network is created without pretrained weights. Its final fully
    connected layer is resized to the number of style classes, the saved
    ``state_dict`` is loaded onto the CPU, and the model is switched to
    evaluation mode.

    Args:
        model_path: Path to a saved ``state_dict`` for the ResNet-18 model.
        class_names: Sequence of class labels. Only its length is used, to
            size the output layer. If ``None`` or empty, the number of
            classes is read from the checkpoint's ``fc.weight`` tensor. The
            old ``[]`` default built a zero-output layer that could never
            load a trained checkpoint.

    Returns:
        The ``torch.nn.Module`` on the CPU, in evaluation mode.

    Raises:
        ValueError: If *class_names* is empty and the checkpoint has no
            ``fc.weight`` entry from which to infer the class count.
    """
    import torch
    from torchvision import models

    # NOTE(review): torch.load unpickles the file, so only load checkpoints
    # from trusted sources (newer torch versions support weights_only=True).
    state_dict = torch.load(model_path, map_location=torch.device("cpu"))

    if class_names:
        num_classes = len(class_names)
    elif isinstance(state_dict, dict) and "fc.weight" in state_dict:
        # fc.weight has shape (num_classes, in_features).
        num_classes = state_dict["fc.weight"].shape[0]
    else:
        raise ValueError(
            "class_names is empty and the checkpoint has no 'fc.weight' "
            "entry to infer the number of classes from"
        )

    try:
        model = models.resnet18(weights=None)
    except TypeError:
        # torchvision < 0.13 has no ``weights`` argument.
        model = models.resnet18(pretrained=False)

    model.fc = torch.nn.Linear(model.fc.in_features, num_classes)
    model.load_state_dict(state_dict)
    model.eval()
    return model
def predict_style(image_path, model, class_names):
    """Classify the image at *image_path* and return its predicted style label.

    The image is converted to RGB, resized to 224x224, and turned into a
    tensor with ``ToTensor``. It is then run through *model* as a batch of
    one with gradient tracking disabled. The index of the largest output
    score is mapped to a label through *class_names*.

    Args:
        image_path: Path to an image file readable by ``PIL.Image.open``.
        model: Classifier called on a ``(1, 3, 224, 224)`` tensor. It is
            assumed to return per-class scores of shape
            ``(1, len(class_names))`` and to already be in eval mode
            (``load_model`` does this).
        class_names: Sequence of labels indexed by the model's class index.

    Returns:
        The element of *class_names* at the predicted index.
    """
    from PIL import Image
    from torchvision import transforms
    import torch

    # NOTE(review): no mean/std normalization is applied here. This must match
    # the preprocessing used when the model was trained; confirm against the
    # training code.
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
    ])

    # The context manager closes the file handle. convert() returns a new,
    # fully loaded image, so it stays valid after the file is closed.
    with Image.open(image_path) as img:
        image = img.convert("RGB")

    batch = transform(image).unsqueeze(0)

    # Put the input on the same device as the model's weights so a model that
    # was moved to a GPU still works. Models without parameters (or plain
    # callables) keep the CPU tensor, as before.
    get_params = getattr(model, "parameters", None)
    first_param = next(get_params(), None) if callable(get_params) else None
    if first_param is not None:
        batch = batch.to(first_param.device)

    with torch.no_grad():
        scores = model(batch)
    predicted = scores.argmax(dim=1)
    return class_names[predicted.item()]