Aklavya commited on
Commit
97a30b6
·
verified ·
1 Parent(s): fdb1768

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +119 -0
app.py ADDED
@@ -0,0 +1,119 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import torch
3
+ from diffusers import StableDiffusionXLPipeline
4
+ # from PIL import Image
5
+ import concurrent.futures
6
+ # import time
7
+
8
+ # Load the models with appropriate pipeline
9
+ model_cache = {}
10
+
11
+
12
+ def load_model(model_name):
13
+ # Check if the model is already cached to avoid reloading every time
14
+ if model_name in model_cache:
15
+ return model_cache[model_name]
16
+
17
+ print(f"Loading model: {model_name}")
18
+ try:
19
+ if model_name == "SG161222/RealVisXL_V5.0_Lightning":
20
+ model = StableDiffusionXLPipeline.from_pretrained(model_name, torch_dtype=torch.float16)
21
+ else:
22
+ model = StableDiffusionXLPipeline.from_pretrained(model_name, torch_dtype=torch.float16)
23
+
24
+ model.to("cuda")
25
+ model_cache[model_name] = model # Cache the model for future use
26
+ print("Model loaded successfully.")
27
+ return model
28
+ except Exception as e:
29
+ print(f"Error loading model: {e}")
30
+ return None
31
+
32
+
33
+ # Function to generate the image with a timeout
34
+ def generate_image_with_timeout(prompt, model_name):
35
+ # Set a timeout for the image generation (120 seconds)
36
+ timeout = 180 # 120 seconds
37
+
38
+ try:
39
+ # Use ThreadPoolExecutor to handle the timeout
40
+ with concurrent.futures.ThreadPoolExecutor() as executor:
41
+ future = executor.submit(generate_image, prompt, model_name)
42
+ return future.result(timeout=timeout) # Will raise TimeoutError if the process exceeds timeout
43
+
44
+ except concurrent.futures.TimeoutError:
45
+ return "Error: The image generation timed out. Please try again."
46
+
47
+
48
+ # Function to generate the image
49
+ def generate_image(prompt, model_name):
50
+ # Load the appropriate model
51
+ model = load_model(model_name)
52
+
53
+ if model is None:
54
+ return "Error loading the model."
55
+
56
+ try:
57
+ # Generate the image from the prompt
58
+ with torch.no_grad():
59
+ image = model(prompt).images[0]
60
+ return image
61
+ except Exception as e:
62
+ print(f"Error generating image: {e}")
63
+ return "Error generating the image."
64
+
65
+
66
+ # Define the Gradio interface using gr.Blocks
67
+ def create_gradio_interface():
68
+ with gr.Blocks(theme=gr.themes.Soft()) as demo:
69
+ # Add a heading at the top of the screen
70
+ gr.Markdown("""
71
+ <h1 style="
72
+ text-align: center;
73
+ color: white;
74
+ font-weight: bold;
75
+ text-transform: uppercase;
76
+ text-decoration: underline;
77
+ margin-top: 30px;
78
+ font-family: 'Arial', sans-serif;
79
+ background: linear-gradient(45deg, #ff6b6b, #f06595);
80
+ padding: 10px 20px;
81
+ border-radius: 15px;
82
+ box-shadow: 0px 4px 8px rgba(0, 0, 0, 0.3);
83
+ ">
84
+ SNAPSCRIBE
85
+ </h1>
86
+ """)
87
+
88
+ # Create a Row for input-output arrangement
89
+ with gr.Row():
90
+ # Create a column for the left input section
91
+ with gr.Column(scale=0.3, min_width=300):
92
+ prompt_input = gr.Textbox(label="Enter your prompt here", placeholder="e.g., A futuristic city skyline")
93
+ model_input = gr.Dropdown(
94
+ choices=["SG161222/RealVisXL_V5.0_Lightning", "SG161222/RealVisXL_V4.0_Lightning"],
95
+ label="Choose model",
96
+ value="SG161222/RealVisXL_V5.0_Lightning"
97
+ )
98
+ submit_button = gr.Button("Generate Image")
99
+
100
+ # Create a column for the right output image section with reduced height (20% smaller)
101
+ with gr.Column(scale=0.7, min_width=600):
102
+ output_image = gr.Image(label="Generated Image", height=640) # Reduced height by 20%
103
+
104
+ # Bind the function to the Gradio interface
105
+ submit_button.click(fn=generate_image_with_timeout, inputs=[prompt_input, model_input], outputs=output_image)
106
+
107
+ # Add footer with a styled span for text in the footer
108
+ gr.Markdown("""
109
+ <div style="position: relative; left: 0; bottom: 0; width: 100%; background-color: #0B0F19; color: white; text-align: center; padding: 10px 0;">
110
+ <p>Developed with ❤ by Aklavya (Bucky)</p>
111
+ </div>
112
+ """)
113
+
114
+ # Launch the interface
115
+ demo.launch(share=True)
116
+
117
+
118
+ # Launch the Gradio interface
119
+ create_gradio_interface()