Jasmeet Singh commited on
import spaces
Browse files
app.py
CHANGED
|
@@ -9,13 +9,15 @@ from torchvision.utils import save_image #to save the generated images
|
|
| 9 |
from tqdm import tqdm # progress bar
|
| 10 |
import matplotlib.pyplot as plt
|
| 11 |
import gradio as gr
|
|
|
|
| 12 |
|
| 13 |
from styleTransfer import style_transfer
|
| 14 |
from dataTransform import tensor_to_image
|
| 15 |
|
| 16 |
-
device = 'cuda'
|
| 17 |
print(device)
|
| 18 |
|
|
|
|
| 19 |
def gradio_style_transfer(steps, content_image, style_image):
|
| 20 |
generated_tensor = style_transfer(content_image, style_image, total_steps= steps)
|
| 21 |
generated_image = tensor_to_image(generated_tensor)
|
|
|
|
| 9 |
from tqdm import tqdm # progress bar
|
| 10 |
import matplotlib.pyplot as plt
|
| 11 |
import gradio as gr
|
| 12 |
+
import spaces
|
| 13 |
|
| 14 |
from styleTransfer import style_transfer
|
| 15 |
from dataTransform import tensor_to_image
|
| 16 |
|
| 17 |
+
device = 'cuda'
|
| 18 |
print(device)
|
| 19 |
|
| 20 |
+
@spaces.GPU
|
| 21 |
def gradio_style_transfer(steps, content_image, style_image):
|
| 22 |
generated_tensor = style_transfer(content_image, style_image, total_steps= steps)
|
| 23 |
generated_image = tensor_to_image(generated_tensor)
|