srijaydeshpande's picture
Update app.py
540c30e verified
raw
history blame
1.39 kB
import argparse
import glob
import os
import random
import shutil
import subprocess
import sys

import gradio as gr
import numpy as np
import PIL
from PIL import Image
from huggingface_hub import snapshot_download
import spaces

from diffusion import generate_latent
from vq_vae import create_mask
# Model weights are downloaded from the Hugging Face Hub at startup
# (an earlier revision read them from a local 'trained_models' directory).
from huggingface_hub import login

# Authenticate only when a token is supplied: login(token=None) falls back to
# an interactive prompt, which cannot be answered in a headless deployment.
hf_token = os.getenv('HF_TOKEN')
if hf_token:
    login(token=hf_token)

# Local filesystem path of the downloaded repository snapshot.
local_dir = snapshot_download(
    repo_id="srijaydeshpande/diffusion"
)
def _load_image(path):
    """Read the image at *path* fully into memory and release its file handle."""
    with Image.open(path) as img:
        # copy() forces the pixel data to load and returns an image that does
        # not depend on the (about to be closed) source file.
        return img.copy()


@spaces.GPU(duration=120)
def create_image(cancer_type):
    """Run the generation pipeline for *cancer_type* and return its two images.

    Steps, all working inside the scratch directory ``./tmp``:
    ``generate_latent`` -> ``create_mask`` -> ``pix2pixhd_test.py``.

    Args:
        cancer_type: Value chosen in the UI radio ("benign" or "malignant").

    Returns:
        ``(input_label_image, synthesized_image)`` as PIL images read from
        ``./tmp/diffusion_dp/test_latest/images``.

    Raises:
        subprocess.CalledProcessError: if ``pix2pixhd_test.py`` exits non-zero.
    """
    tmp_dir = "./tmp"
    # Start from an empty scratch directory so outputs from a previous request
    # can never be returned for this one.
    if os.path.exists(tmp_dir):
        shutil.rmtree(tmp_dir)
    os.makedirs(tmp_dir)

    # The old `model_dir` global is commented out at module level, so use the
    # snapshot downloaded from the Hub at import time.
    model_dir = local_dir
    generate_latent(model_dir, cancer_type, tmp_dir)
    create_mask(model_dir, tmp_dir, os.path.join(tmp_dir, "test_masks"))

    # Run pix2pixhd_test.py with the same interpreter as this app; check=True
    # reports a failure here instead of as a missing-file error further down.
    subprocess.run(
        [
            sys.executable, "pix2pixhd_test.py",
            "--name", "diffusion_dp",
            "--dataroot", tmp_dir,
            "--label_nc", "0",
            "--no_instance",
            "--resize_or_crop", "none",
        ],
        check=True,
    )

    image_dir = os.path.join(tmp_dir, "diffusion_dp", "test_latest", "images")
    input_label_image = _load_image(os.path.join(image_dir, "sample_input_label.jpg"))
    synthesized_image = _load_image(os.path.join(image_dir, "sample_synthesized_image.jpg"))
    return input_label_image, synthesized_image
# Web UI: one radio selector feeds create_image; its two returned images are
# displayed in two image output components.
type_selector = gr.Radio(
    choices=["benign", "malignant"],
    label="Choose Type",
    value="benign",
)
result_views = [gr.Image(), gr.Image()]

demo = gr.Interface(
    fn=create_image,
    inputs=type_selector,
    outputs=result_views,
    title="Diffusion based Image Generation",
)
demo.launch()
# For a quick check without the UI: create_image('benign')