# image-Style / predict.py
# Author: d-e-e-k-11 — "Upload folder using huggingface_hub" (commit d1bfee5, verified)
import torch
from models import Generator
from PIL import Image
import torchvision.transforms as transforms
import numpy as np
import os
import matplotlib.pyplot as plt
def predict(model, image_path, device="cpu", image_size=(256, 256)):
    """Run one image through an image-to-image generator and return it as uint8.

    The image is loaded as RGB, resized to ``image_size``, converted to a
    tensor and normalized with mean/std 0.5 per channel (mapping [0, 1] to
    [-1, 1]). The model output is mapped back from [-1, 1] to [0, 255] and
    returned channel-last.

    Args:
        model: A ``torch.nn.Module`` called on a (1, C, H, W) batch. Assumes
            it returns an image tensor of the same layout with values roughly
            in [-1, 1] (e.g. a tanh head) — TODO confirm against ``Generator``.
            Side effect: the model is switched to eval mode.
        image_path: Path to the input image file.
        device: Device the input tensor is moved to; should match the model's.
        image_size: Size passed to ``transforms.Resize``. Defaults to the
            original fixed 256x256.

    Returns:
        ``numpy.ndarray`` of dtype ``uint8`` laid out as (H, W, C).
    """
    transform = transforms.Compose([
        transforms.Resize(image_size),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
    ])
    # Close the underlying file promptly; convert() returns a fully loaded
    # image that no longer depends on the open file handle.
    with Image.open(image_path) as img:
        image = img.convert("RGB")
    image_tensor = transform(image).unsqueeze(0).to(device)

    model.eval()
    with torch.no_grad():
        prediction = model(image_tensor)

    # Drop the batch dim and move to host memory. Under no_grad no autograd
    # graph is recorded, so an explicit detach() is unnecessary.
    prediction = prediction.squeeze(0).cpu().numpy()
    # Undo Normalize(0.5, 0.5): [-1, 1] -> [0, 1], then CHW -> HWC.
    prediction = (prediction * 0.5 + 0.5).transpose(1, 2, 0)
    # Saturate before the uint8 cast; out-of-range values (e.g. from an
    # untrained model) would otherwise wrap around instead of clipping.
    prediction = np.clip(prediction, 0.0, 1.0)
    return (prediction * 255).astype(np.uint8)
def main(checkpoint_path="genz.pth.tar",
         test_image="data/horse2zebra/testA/n02381460_1010.jpg",
         output_path="prediction_result.png"):
    """Load the horse->zebra generator, stylize one image and save the plot.

    Falls back to an untrained generator when ``checkpoint_path`` is missing,
    and prints a message (without raising) when ``test_image`` is missing.

    Args:
        checkpoint_path: Generator checkpoint to load if it exists.
        test_image: Input image to stylize (an example horse image by default).
        output_path: Where the matplotlib figure is saved.
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"
    gen_Z = Generator(img_channels=3).to(device)

    if os.path.exists(checkpoint_path):
        # NOTE(review): torch.load unpickles the file — only load trusted
        # checkpoints (weights_only=True is available on torch >= 1.13).
        checkpoint = torch.load(checkpoint_path, map_location=device)
        # Accept either a bare state_dict or a training checkpoint that nests
        # it under "state_dict" (common for *.pth.tar files — confirm against
        # the training script's save format).
        if isinstance(checkpoint, dict) and "state_dict" in checkpoint:
            checkpoint = checkpoint["state_dict"]
        gen_Z.load_state_dict(checkpoint)
        print(f"Loaded checkpoint from {checkpoint_path}")
    else:
        print("Using untrained model (no checkpoint found).")

    if not os.path.exists(test_image):
        print(f"Test image {test_image} not found.")
        return

    result = predict(gen_Z, test_image, device)
    fig = plt.figure()
    plt.imshow(result)
    plt.title("Style Transferred Image (Zebra)")
    plt.axis("off")
    plt.savefig(output_path)
    # Release the figure so repeated calls don't accumulate open figures.
    plt.close(fig)
    print(f"Prediction saved to {output_path}")
if __name__ == "__main__":
    # Only run the demo when executed as a script, not when imported.
    main()