Spaces: Build error
import gradio as gr
from diffusers import StableDiffusionPipeline
import torch
import re


def split_into_sentences(text):
    """
    Splits the input text into individual sentences.
    This helps in identifying key scenes for image generation.
    """
    # Simple sentence splitter based on punctuation
    sentences = re.split(r'(?<=[.!?]) +', text)
    return sentences
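
# Example (illustrative): split_into_sentences("A hero rises. The city falls! Who wins?")
# returns ["A hero rises.", "The city falls!", "Who wins?"]. Note that the regex only
# splits on sentence-ending punctuation followed by a space, so abbreviations such as
# "Dr. Who" are also split.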


def generate_comic_strip(story):
    """
    Generates a comic strip from the input story.

    Parameters:
    - story (str): The user's story prompt.

    Returns:
    - comic_strip (list): A list of generated images representing each scene.
    """
    if pipe is None:
        return ["https://via.placeholder.com/512x512.png?text=Model+Not+Loaded"]

    # Split the story into sentences to identify key scenes
    scenes = split_into_sentences(story)

    # Limit the number of scenes to prevent excessive image generation
    max_scenes = 3
    scenes = scenes[:max_scenes]

    comic_strip = []
    for idx, scene in enumerate(scenes):
        try:
            # Generate image for each scene with optimizations
            image = pipe(
                scene,
                num_inference_steps=20,  # Reduced steps for faster generation
                height=256,              # Reduced resolution
                width=256,               # Reduced resolution
                guidance_scale=7.5,      # Default guidance scale
            ).images[0]
            comic_strip.append(image)
        except Exception as e:
            # In case of any error during image generation, append a placeholder image
            print(f"Error generating image for scene {idx + 1}: {e}")
            comic_strip.append("https://via.placeholder.com/512x512.png?text=Image+Unavailable")
    return comic_strip
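
# Note: the returned list may mix PIL images (successful scenes) with URL strings
# (placeholders); gr.Gallery can render both, so failed scenes still show a panel.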


def main():
    """
    Sets up the Gradio interface for the GenArt Narrative application.
    """
    # Define the input component: a textbox for the user to input their story
    input_text = gr.Textbox(
        lines=5,
        placeholder="Enter your short story here...",
        label="Story Prompt"
    )

    # Define the output component: a gallery to display the generated comic strip
    output_gallery = gr.Gallery(
        label="Generated Comic Strip",
        columns=3,
        object_fit="contain",
        height="auto"
    )

    # Create the Gradio interface
    iface = gr.Interface(
        fn=generate_comic_strip,   # Function to process input and generate output
        inputs=input_text,         # Input component
        outputs=output_gallery,    # Output component
        title="GenArt Narrative",  # Title of the app
        description="Transform your short stories into engaging comic strips using AI-powered image generation.",
        examples=[  # Example inputs for demonstration
            ["A young wizard discovers a hidden magical forest and befriends a talking owl."],
            ["An astronaut lands on a distant planet and encounters alien life forms."]
        ],
        allow_flagging="never",  # Disable flagging of outputs
        theme="default",         # You can choose other themes like "huggingface"
    )

    # Launch the Gradio app
    iface.launch()


# Load the model globally to avoid reloading it for each request
pipe = None
try:
    print("Loading Stable Diffusion model...")
    # Initialize the Stable Diffusion pipeline with optimizations
    pipe = StableDiffusionPipeline.from_pretrained(
        "stabilityai/stable-diffusion-2-1-base",  # Model name
        torch_dtype=torch.float32,  # Use float32 for CPU
        low_cpu_mem_usage=True,     # Optimize for CPU usage
        safety_checker=None,        # Disable safety checker to speed up loading
        force_download=True         # Force download to avoid resume_download warning
    )
    pipe = pipe.to("cpu")            # Move the model to CPU
    pipe.enable_attention_slicing()  # Reduce memory usage
    print("Model loaded successfully.")
except Exception as e:
    print(f"Error loading model: {e}")
    pipe = None
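
# Note (assumption, not part of the original setup): if a GPU were available, the
# pipeline could instead be loaded with torch_dtype=torch.float16 and moved with
# pipe.to("cuda") for much faster generation; this Space targets CPU, hence the
# float32 dtype and attention slicing above.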


if __name__ == "__main__":
    if pipe is not None:
        main()
    else:
        print("Failed to load the model. Please check the error messages above.")