# Hugging Face Spaces app page residue: "Spaces: Sleeping" (app status when this file was captured).
import logging
import os
from contextlib import nullcontext

import gradio as gr  # Import Gradio
import torch
from diffusers import StableDiffusionPipeline
from dotenv import load_dotenv
from PIL import Image
from torch import autocast
# Load variables from a local .env file into the process environment so the
# AUTH_TOKEN lookup in StableBuddyApp.__init__ (os.getenv) can find them.
load_dotenv()
class StableBuddyApp:
    """Wrap a Stable Diffusion text-to-image pipeline for use by the Gradio UI.

    Attributes:
        device: ``"cuda"`` when ``torch.cuda.is_available()`` is true, else ``"cpu"``.
        pipe: the ``StableDiffusionPipeline`` loaded in ``__init__`` and moved to ``device``.
    """

    def __init__(self, model_id="CompVis/stable-diffusion-v1-4"):
        """Load the Stable Diffusion pipeline onto the best available device.

        Args:
            model_id: Hugging Face Hub id of the checkpoint to load. Defaults to
                the checkpoint this app was originally hard-coded to use.

        Raises:
            ValueError: if the ``AUTH_TOKEN`` environment variable is unset or empty.
        """
        # Stored on the instance so generate_image can pick the precision context.
        self.device = "cuda" if torch.cuda.is_available() else "cpu"

        auth_token = os.getenv("AUTH_TOKEN")
        if not auth_token:
            raise ValueError("AUTH_TOKEN environment variable is not set.")

        # float16 on the GPU to keep VRAM usage down; float32 on the CPU.
        torch_dtype = torch.float16 if self.device == "cuda" else torch.float32
        # NOTE(review): recent diffusers releases deprecate revision="fp16" and
        # use_auth_token in favour of variant="fp16" and token — confirm against
        # the pinned diffusers version before switching.
        self.pipe = StableDiffusionPipeline.from_pretrained(
            model_id,
            revision="fp16",
            torch_dtype=torch_dtype,
            use_auth_token=auth_token,
        )
        self.pipe.to(self.device)

    def generate_image(self, prompt, guidance_scale=8.5, output_dir="data"):
        """Generate one image for *prompt* and save it as a PNG file.

        Args:
            prompt: text prompt passed to the pipeline.
            guidance_scale: guidance strength forwarded to the pipeline
                (default 8.5, the value previously hard-coded).
            output_dir: directory that receives ``generated_image.png``; it is
                created (including missing parents) when absent.

        Returns:
            The path of the saved image, or ``None`` if any step failed. Failures
            are logged with their traceback instead of being raised, so the
            caller receives ``None`` rather than an exception.
        """
        try:
            # Autocast only on the GPU; on the CPU run the pipeline unchanged.
            precision_ctx = (
                autocast(self.device) if self.device == "cuda" else nullcontext()
            )
            with precision_ctx:
                image = self.pipe(prompt, guidance_scale=guidance_scale).images[0]

            # exist_ok avoids the check-then-create race of exists() + makedirs().
            os.makedirs(output_dir, exist_ok=True)
            # NOTE: every call writes the same file name, so concurrent requests
            # sharing one output_dir can overwrite each other's image.
            image_path = os.path.join(output_dir, "generated_image.png")
            image.save(image_path)
            return image_path  # Gradio displays / serves the file at this path
        except Exception:
            logging.getLogger(__name__).exception(
                "Image generation failed for prompt %r", prompt
            )
            return None
# Module-level app shared by the Gradio callback; constructing it loads the
# Stable Diffusion pipeline once, when this module is first executed.
stable_buddy_app: StableBuddyApp = StableBuddyApp()
# Gradio callback feeding both the image preview and the download output.
def generate_and_download(prompt):
    """Generate an image for *prompt* and route its path to both outputs.

    Returns a 2-tuple holding the same file path twice: the first entry is for
    the image preview, the second for the file download. Both entries are
    ``None`` when ``stable_buddy_app.generate_image`` returned ``None``.
    """
    saved_path = stable_buddy_app.generate_image(prompt)
    return (saved_path,) * 2
# Build the Gradio UI: one prompt box in, an image preview plus a downloadable
# file out, both filled from generate_and_download's 2-tuple.
_prompt_input = gr.Textbox(label="Enter Prompt")
_result_outputs = [
    gr.Image(type="filepath", label="Generated Image"),
    gr.File(label="Download Image"),
]
iface = gr.Interface(
    fn=generate_and_download,
    inputs=_prompt_input,
    outputs=_result_outputs,
    title="Stable Buddy",
    description="Generate images using Stable Diffusion.",
)

# Start the web server (runs at import time, as in the original script).
iface.launch()