Create style_transfer.py
Browse files- style_transfer.py +20 -0
style_transfer.py
ADDED
|
@@ -0,0 +1,20 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
def transfer_style(image_path, target_style_name):
|
| 2 |
+
|
| 3 |
+
G = ConditionalGenerator(num_styles=3)
|
| 4 |
+
G.load_state_dict(torch.load("generator.pth"))
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
img = Image.open(image_path).convert("RGB")
|
| 8 |
+
transform = T.Compose([T.Resize(256), T.ToTensor(), T.Normalize(0.5, 0.5)])
|
| 9 |
+
img_tensor = transform(img).unsqueeze(0)
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
style_id = style_name_to_id[target_style_name]
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
with torch.no_grad():
|
| 16 |
+
stylized = G(img_tensor, torch.tensor([style_id]))
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
result = ToPILImage()((stylized.squeeze() + 1) / 2) # Денормализация [0, 1]
|
| 20 |
+
result.save("stylized.jpg")
|