Update generator_model_40.pth loading in app.py
Browse files
app.py
CHANGED
|
@@ -57,7 +57,7 @@ if (device.type == 'cuda') and (ngpu > 1):
|
|
| 57 |
generator = nn.DataParallel(generator, list(range(ngpu)))
|
| 58 |
discriminator = nn.DataParallel(discriminator, list(range(ngpu)))
|
| 59 |
seg_model = nn.DataParallel(seg_model, list(range(ngpu)))
|
| 60 |
-
|
| 61 |
|
| 62 |
def inference(sketch_path,label):
|
| 63 |
transform_sketch = T.Compose(
|
|
@@ -79,8 +79,8 @@ def inference(sketch_path,label):
|
|
| 79 |
return fake_images.cpu().detach().numpy()
|
| 80 |
|
| 81 |
|
| 82 |
-
audio_1 = gr.Image(sources="upload", type="filepath", label="
|
| 83 |
-
audio_2 = gr.Image(sources="upload", type="filepath", label="
|
| 84 |
# text_output = gr.Textbox(label="Similarity Score")
|
| 85 |
image_out = gr.Image(label="Generated Image")
|
| 86 |
gr.Interface(
|
|
|
|
| 57 |
generator = nn.DataParallel(generator, list(range(ngpu)))
|
| 58 |
discriminator = nn.DataParallel(discriminator, list(range(ngpu)))
|
| 59 |
seg_model = nn.DataParallel(seg_model, list(range(ngpu)))
|
| 60 |
+
generator.load_state_dict(torch.load('generator_model_40.pth'))
|
| 61 |
|
| 62 |
def inference(sketch_path,label):
|
| 63 |
transform_sketch = T.Compose(
|
|
|
|
| 79 |
return fake_images.cpu().detach().numpy()
|
| 80 |
|
| 81 |
|
| 82 |
+
audio_1 = gr.Image(sources="upload", type="filepath", label="img 1")
|
| 83 |
+
audio_2 = gr.Image(sources="upload", type="filepath", label="img 2")
|
| 84 |
# text_output = gr.Textbox(label="Similarity Score")
|
| 85 |
image_out = gr.Image(label="Generated Image")
|
| 86 |
gr.Interface(
|