SkillForge45 commited on
Commit
af52a92
·
verified ·
1 Parent(s): 8388b4e

Create style_transfer.py

Browse files
Files changed (1) hide show
  1. style_transfer.py +20 -0
style_transfer.py ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ def transfer_style(image_path, target_style_name):
2
+
3
+ G = ConditionalGenerator(num_styles=3)
4
+ G.load_state_dict(torch.load("generator.pth"))
5
+
6
+
7
+ img = Image.open(image_path).convert("RGB")
8
+ transform = T.Compose([T.Resize(256), T.ToTensor(), T.Normalize(0.5, 0.5)])
9
+ img_tensor = transform(img).unsqueeze(0)
10
+
11
+
12
+ style_id = style_name_to_id[target_style_name]
13
+
14
+
15
+ with torch.no_grad():
16
+ stylized = G(img_tensor, torch.tensor([style_id]))
17
+
18
+
19
+ result = ToPILImage()((stylized.squeeze() + 1) / 2) # Денормализация [0, 1]
20
+ result.save("stylized.jpg")