Spaces:
Sleeping
Sleeping
| import streamlit as st | |
| import numpy as np | |
| from tqdm.auto import trange | |
| import torch | |
| from source import Model, Block | |
def cvtImg(img):
    """Min-max rescale a tensor and return it as a float32 NumPy array.

    The minimum of ``img`` is subtracted first, then the result is divided by
    its new maximum, so a non-constant input ends up spanning 0 to 1.

    Args:
        img: A CPU ``torch.Tensor`` (``.numpy()`` is called on it directly).

    Returns:
        ``numpy.ndarray`` of dtype ``float32`` with the same shape as ``img``.
        A constant-valued input yields all zeros.
    """
    shifted = img - img.min()
    peak = shifted.max()
    # A constant image has zero range after the shift; dividing by that peak
    # would produce NaN (0 / 0), so the division is skipped in that case.
    if peak > 0:
        shifted = shifted / peak
    return shifted.numpy().astype(np.float32)
def load_DM():
    """Build the diffusion model and try to replace it with the saved one.

    A fresh ``Model`` is created on the selected device first. The function
    then attempts to load a whole serialized model object from
    ``./Diffusion-cuda``. Progress and failures are reported on the Streamlit
    page; a failed load does not raise, so the freshly constructed model is
    what gets returned in that case.

    Returns:
        tuple: ``(model, device)`` where ``device`` is CUDA when available and
        CPU otherwise.
    """
    checkpoint_path = "./Diffusion-cuda"

    if torch.cuda.is_available():
        device = torch.device('cuda')
    else:
        device = torch.device('cpu')
    st.write("Devise used:", device)

    model = Model().to(device)
    st.write("Diffusion Model's architecture created.")

    try:
        # The file holds the full model object rather than a state dict.
        # NOTE(review): whole-object loading goes through pickle; only point
        # this at checkpoints from a trusted source.
        model = torch.load(checkpoint_path, map_location=device)
        st.write("Diffusion Model loaded.")
    except FileNotFoundError:
        st.error("Model file not found at specified path.")
    except Exception as err:
        st.error(f"Error occurred while loading model: {err}")

    return model, device
def predict_step_new(model, device):
    """Run the denoising loop and stream every intermediate image to the page.

    Sampling starts from Gaussian noise of shape ``(1, 3, 32, 32)``. For each of
    the 16 time steps the model is called once, the first image of its output
    is moved to channel-last order, clipped to ``[-1, 1]``, rescaled with
    ``cvtImg`` and shown in a single Streamlit placeholder while a progress bar
    advances.

    Args:
        model: Callable invoked as ``model(x, t)`` whose result is fed back in
            as the next ``x``.
        device: Device on which the noise and time-step tensors are created.

    Returns:
        list: One rescaled float32 NumPy image per time step, in the order
        they were produced.
    """
    IMG_SIZE = 32
    timesteps = 16
    progress_text = "Denoising the image."
    my_bar = st.progress(0, text=progress_text)
    x = torch.randn(size=(1, 3, IMG_SIZE, IMG_SIZE), device=device)
    placeholder = st.empty()
    frames = []
    # Derive the percentage from `timesteps` so the bar ends at exactly 100 on
    # the final step instead of relying on a hard-coded divisor.
    last_step = max(timesteps - 1, 1)
    with torch.no_grad():
        for t in trange(timesteps):
            my_bar.progress(t * 100 // last_step, text=progress_text)
            # NOTE(review): the time tensor has batch size 8 while `x` starts
            # with batch size 1; kept as-is because the model's handling of
            # the mismatch is not visible here. Confirm against Model.forward.
            x = model(x, torch.full([8, 1], t, dtype=torch.float, device=device))
            frame = torch.clip(x[0].permute(1, 2, 0), -1, 1)
            frame = cvtImg(frame.cpu())
            placeholder.image(frame, caption=f"Time step {t}", width=500)
            frames.append(frame)
    st.markdown("Brand new **car** generated!")
    return frames
def main():
    """Render the Streamlit page: model import, generation and history.

    Session state keeps three entries across reruns: ``model`` (an empty list
    until a model is imported), ``device`` (``None`` until a model is imported)
    and ``generation`` (a list with one list of images per completed run).
    """
    st.set_page_config(page_title="Diffusion Model App Demo")
    # NOTE(review): the characters after "App" look like mis-encoded emoji;
    # left untouched here, confirm the intended glyphs.
    st.header("Diffusion Model App ๐ผ๏ธ ๐", divider="rainbow")
    # The backslash is escaped so "\dag" reaches Streamlit unchanged without
    # relying on an invalid escape sequence in a normal string literal.
    st.markdown(
        "Made by:\n"
        "Marijan$^\\dag$, Thomas, Mauritz, Amanda and Baris for **COGITO** Diffusion Model"
    )

    if "model" not in st.session_state:
        st.session_state.model = []
    if "generation" not in st.session_state:
        st.session_state.generation = []
    if "device" not in st.session_state:
        st.session_state.device = None

    placeholder_1 = st.empty()
    isclick = placeholder_1.button('Import Diffusion Model')
    if isclick:
        st.session_state.model, st.session_state.device = load_DM()
        placeholder_1.empty()

    if st.button('Generate!'):
        # `device` is only set by load_DM, so None means no model has been
        # imported yet and `model` is still the non-callable [] placeholder.
        if st.session_state.device is None:
            st.warning("Import the Diffusion Model before generating.")
        else:
            li = predict_step_new(st.session_state.model, st.session_state.device)
            st.session_state.generation.append(li)

    # Show every stored run, oldest first, as a strip of small thumbnails.
    for num, gen in enumerate(st.session_state.generation):
        st.write("Generation number", num+1)
        st.image(gen, width=40)
# Start the app only when this file is executed as a script.
if __name__ == "__main__":
    main()