Spaces:
Paused
Paused
| import streamlit as st | |
| from PIL import Image | |
| from inference import inference | |
| import torch | |
| import io | |
| from diffusion import DiffusionImageAPI | |
| import math | |
# Genre name -> 1-based genre code. Code ``c`` is written into slot ``c - 1``
# of the conditioning vector that ``main`` passes to ``inference``.
GENRES = {
    'Action': 1,
    'Adventure': 2,
    'Animation': 3,
    'Comedy': 4,
    'Drama': 5,
    'Family': 6,
    'Horror': 7,
    'Music': 8,
    'Romance': 9,
    'Science Fiction': 10,
    'Western': 11,
    'Fantasy': 12,
    'Thriller': 13,
}

# Length of the conditioning vector handed to ``inference``. Only slots
# 0..len(GENRES)-1 are ever written below, so the last slot always stays 0;
# presumably the model was trained with this width -- TODO confirm.
_COND_SIZE = 14


def main():
    """Render the "Movie Diffusion" Streamlit page and generate on demand.

    The sidebar offers a multiselect of genres. When "Generate Image" is
    pressed, a genre-conditioning tensor is built (selected genre codes are
    written into their slots) and passed to ``inference``. ``inference`` is
    given a callback that renders each image it reports, together with a
    progress readout, into two placeholders on the page.
    """
    st.title("Movie Diffusion")

    # Conditioning vector: slot ``code - 1`` holds ``code`` for each active
    # genre, 0 otherwise. Same values/dtype (int64) as a literal of zeros.
    cond = torch.zeros(_COND_SIZE, dtype=torch.long)

    # Add a sidebar for genre selection
    selected_genres = st.sidebar.multiselect('Select Genres', list(GENRES.keys()))

    # Placeholders are updated in place by the callback so the page shows the
    # latest image/progress rather than appending a new element per update.
    progress_placeholder = st.empty()
    image_placeholder = st.empty()

    # Button to trigger image generation
    if st.button('Generate Image'):
        for genre in selected_genres:
            code = GENRES[genre]
            cond[code - 1] = code

        # NOTE(review): this adds one extra, randomly chosen genre whenever the
        # user has selected at least one genre, and does nothing when none are
        # selected. "Pick a random genre when nothing is selected" seems a
        # likelier intent; behaviour is kept as-is pending confirmation.
        if torch.any(cond != 0):
            random_index = torch.randint(0, len(GENRES), (1,)).item()
            cond[random_index] = random_index + 1

        # Built once per run instead of on every callback invocation.
        image_api = DiffusionImageAPI(None)

        def callback(image, progress):
            # assumes ``image`` carries a leading batch dim of size 1, which
            # squeeze(0) drops -- TODO confirm against ``inference``.
            pil_image = image_api.tensor_to_image(image.squeeze(0))
            img_buffer = io.BytesIO()
            pil_image.save(img_buffer, format="PNG")
            img_buffer.seek(0)
            # Update the content of the placeholders. Progress is scaled by
            # 110 and clamped at 100 -- presumably because the reported value
            # never quite reaches 1.0; confirm against ``inference``.
            progress_placeholder.write(f"Generating Image...\nProgress: {min(progress * 110, 100):.2f}%")
            image_placeholder.image(img_buffer, caption='Generated Image', width=300)

        inference(cond, callback=callback)


if __name__ == "__main__":
    main()