marijanic's picture
Create app.py
43da96c verified
import streamlit as st
import numpy as np
from tqdm.auto import trange
import torch
from source import Model, Block
def cvtImg(img):
img = img - img.min()
img = (img / img.max())
return img.numpy().astype(np.float32)
def load_DM():
PATH = "./Diffusion-cuda"
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
st.write("Devise used:", device)
model = Model().to(device)
st.write("Diffusion Model's architecture created.")
try:
# Load the entire model object
model = torch.load(PATH, map_location=device)
st.write("Diffusion Model loaded.")
except FileNotFoundError:
st.error("Model file not found at specified path.")
except Exception as e:
st.error(f"Error occurred while loading model: {e}")
return model, device
def predict_step_new(model, device):
xs = []
IMG_SIZE = 32
timesteps = 16
progress_text = "Denoising the image."
my_bar = st.progress(0, text=progress_text)
x = torch.randn(size=(1, 3, IMG_SIZE, IMG_SIZE), device=device)
placeholder = st.empty()
li = []
with torch.no_grad():
for i in trange(timesteps):
my_bar.progress(i*100//15, text=progress_text)
t = i
x = model(x, torch.full([8, 1], t, dtype=torch.float, device=device))
xs = x[0].permute(1, 2, 0)
xs = torch.clip(xs, -1, 1)
xs = cvtImg(xs.cpu())
images = np.transpose(xs, (2, 0, 1))
images = np.transpose(images, (1, 2, 0))
placeholder.image(images, caption=f"Time step {t}", width=500)
li.append(images)
st.markdown("Brand new **car** generated!")
return li
def main():
st.set_page_config(page_title="Diffusion Model App Demo")
st.header("Diffusion Model App ๐Ÿ–ผ๏ธ ๐Ÿš•", divider="rainbow")
st.markdown("""Made by:
Marijan$^\dag$, Thomas, Mauritz, Amanda and Baris for **COGITO** Diffusion Model""")
if "model" not in st.session_state:
st.session_state.model = []
if "generation" not in st.session_state:
st.session_state.generation = []
if "device" not in st.session_state:
st.session_state.device = None
placeholder_1 = st.empty()
isclick = placeholder_1.button('Import Diffusion Model')
if isclick:
st.session_state.model, st.session_state.device = load_DM()
placeholder_1.empty()
if st.button('Generate!'):
li = predict_step_new(st.session_state.model, st.session_state.device)
st.session_state.generation.append(li)
for num, gen in enumerate(st.session_state.generation):
st.write("Generation number", num+1)
st.image(gen, width=40)
if __name__ == "__main__":
main()