Spaces:
Runtime error
Runtime error
Om-Alve
commited on
Commit
·
319f6be
1
Parent(s):
a459d13
downscaling
Browse files
app.py
CHANGED
|
@@ -35,7 +35,7 @@ class StyleTransfer(nn.Module):
|
|
| 35 |
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
| 36 |
|
| 37 |
def image_merger(content, style,beta=10,device=device):
|
| 38 |
-
size =
|
| 39 |
alpha = 1
|
| 40 |
beta *= 1000
|
| 41 |
content = Image.fromarray(content)
|
|
@@ -52,7 +52,7 @@ def image_merger(content, style,beta=10,device=device):
|
|
| 52 |
generator = StyleTransfer().to(device).eval()
|
| 53 |
opt = torch.optim.Adam([generated],lr=0.06)
|
| 54 |
scheduler = torch.optim.lr_scheduler.StepLR(opt, step_size=5, gamma=0.9) # Learning rate scheduler
|
| 55 |
-
num_epochs = 30 if device
|
| 56 |
style_features,_ = generator(style)
|
| 57 |
_,content_features = generator(content)
|
| 58 |
loop = tqdm(range(num_epochs),leave=False)
|
|
@@ -74,7 +74,7 @@ def image_merger(content, style,beta=10,device=device):
|
|
| 74 |
total_loss.backward(retain_graph=True)
|
| 75 |
opt.step()
|
| 76 |
scheduler.step()
|
| 77 |
-
if total_loss < 200 and device
|
| 78 |
break
|
| 79 |
print(total_loss.item())
|
| 80 |
img = np.array(generated.cpu().detach().squeeze(0).permute(1,2,0))
|
|
|
|
| 35 |
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
| 36 |
|
| 37 |
def image_merger(content, style,beta=10,device=device):
|
| 38 |
+
size = 300
|
| 39 |
alpha = 1
|
| 40 |
beta *= 1000
|
| 41 |
content = Image.fromarray(content)
|
|
|
|
| 52 |
generator = StyleTransfer().to(device).eval()
|
| 53 |
opt = torch.optim.Adam([generated],lr=0.06)
|
| 54 |
scheduler = torch.optim.lr_scheduler.StepLR(opt, step_size=5, gamma=0.9) # Learning rate scheduler
|
| 55 |
+
num_epochs = 30 if device != "cuda" else 100
|
| 56 |
style_features,_ = generator(style)
|
| 57 |
_,content_features = generator(content)
|
| 58 |
loop = tqdm(range(num_epochs),leave=False)
|
|
|
|
| 74 |
total_loss.backward(retain_graph=True)
|
| 75 |
opt.step()
|
| 76 |
scheduler.step()
|
| 77 |
+
if total_loss < 200 and device!='cuda':
|
| 78 |
break
|
| 79 |
print(total_loss.item())
|
| 80 |
img = np.array(generated.cpu().detach().squeeze(0).permute(1,2,0))
|