Aklavya committed on
Commit
03a12e6
·
verified ·
1 Parent(s): 20d431f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +107 -44
app.py CHANGED
@@ -1,89 +1,145 @@
1
  import os
2
- import random
3
  import uuid
 
 
 
4
  import gradio as gr
5
  import numpy as np
6
  from PIL import Image
7
- import spaces
8
  import torch
 
9
  from diffusers import StableDiffusionXLPipeline, EulerAncestralDiscreteScheduler
10
- from typing import Tuple
11
 
 
 
 
 
12
  def apply_style(style_name: str, positive: str, negative: str = "") -> Tuple[str, str]:
13
  styles = {
14
  "3840 x 2160": (
15
- "hyper-realistic image of {prompt}. lifelike, authentic, natural colors, true-to-life details, landscape image, realistic lighting, immersive, highly detailed",
 
16
  "unrealistic, low resolution, artificial, over-saturated, distorted, fake",
17
  ),
18
  "Style Zero": ("{prompt}", ""),
19
  }
20
  DEFAULT_STYLE_NAME = "3840 x 2160"
21
-
22
  p, n = styles.get(style_name, styles[DEFAULT_STYLE_NAME])
23
- return p.replace("{prompt}", positive), n + negative
 
24
 
25
- def load_and_prepare_model():
 
 
 
26
  model_id = "SG161222/RealVisXL_V5.0_Lightning"
 
 
 
 
27
  pipe = StableDiffusionXLPipeline.from_pretrained(
28
  model_id,
29
- torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
30
  use_safetensors=True,
31
  add_watermarker=False,
32
- ).to(torch.device("cuda:0" if torch.cuda.is_available() else "cpu"))
 
 
33
  pipe.scheduler = EulerAncestralDiscreteScheduler.from_config(pipe.scheduler.config)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
34
  return pipe
35
 
 
 
36
  model = load_and_prepare_model()
37
 
 
 
 
 
38
  def randomize_seed_fn(seed: int, randomize_seed: bool) -> int:
39
  if randomize_seed:
40
  seed = random.randint(0, np.iinfo(np.int32).max)
41
- return seed
 
42
 
43
- def save_image(img):
44
- unique_name = str(uuid.uuid4()) + ".png"
 
45
  img.save(unique_name)
46
  return unique_name
47
 
 
 
 
 
48
  @spaces.GPU(duration=60, enable_queue=True)
49
  def generate(
50
- prompt: str,
51
- seed: int = 1,
52
- width: int = 1024,
53
- height: int = 1024,
54
- guidance_scale: float = 3,
55
- num_inference_steps: int = 25,
56
- randomize_seed: bool = False,
57
  ):
58
  global model
59
 
60
- seed = int(randomize_seed_fn(seed, randomize_seed))
 
 
 
 
 
 
 
 
61
  generator = torch.Generator(device=model.device).manual_seed(seed)
62
 
63
  positive_prompt, negative_prompt = apply_style("3840 x 2160", prompt)
64
 
65
- options = {
66
- "prompt": [positive_prompt],
67
- "negative_prompt": [negative_prompt],
68
- "width": width,
69
- "height": height,
70
- "guidance_scale": guidance_scale,
71
- "num_inference_steps": num_inference_steps,
72
- "generator": generator,
73
- "output_type": "pil",
74
- }
75
-
76
- images = model(**options).images
77
- image_path = save_image(images[0]) # Saving the first generated image
78
  return image_path
79
 
 
 
 
 
80
  with gr.Blocks(theme="soft") as demo:
81
- # Block for "SNAPSCRIBE" centered at the top
82
  with gr.Row():
83
  with gr.Column(scale=12, elem_id="title_block"):
84
- gr.Markdown("<h1 style='text-align:center; color:white; font-weight:bold; text-decoration:underline;'>SNAPSCRIBE</h1>")
85
- gr.Markdown("<h2 style='text-align:center; color:white; font-weight:bold; text-decoration:underline;'>Developed using RealVisXL_V5.0_Lightning model with ❤ by Aklavya</h2>")
86
-
 
 
 
 
87
  with gr.Row():
88
  with gr.Column(scale=3):
89
  prompt = gr.Textbox(
@@ -91,9 +147,15 @@ with gr.Blocks(theme="soft") as demo:
91
  placeholder="Describe the image you want to create",
92
  lines=2,
93
  )
94
- run_button = gr.Button("Generate Image")
 
 
 
 
 
 
 
95
 
96
- # Example prompts box
97
  example_prompts_text = (
98
  "Dew-covered spider web in morning sunlight, with blurred greenery\n"
99
  "--------------------------------------------\n"
@@ -104,9 +166,9 @@ with gr.Blocks(theme="soft") as demo:
104
  "Autumn forest with golden leaves, sunlight through trees, and a breeze"
105
  )
106
 
107
- example_prompts = gr.Textbox(
108
  value=example_prompts_text,
109
- lines=5,
110
  label="Sample Inputs",
111
  interactive=False,
112
  )
@@ -116,13 +178,14 @@ with gr.Blocks(theme="soft") as demo:
116
  label="Generated Image",
117
  type="filepath",
118
  elem_id="output_image",
119
- height=650 # Increased the height by 100%
120
  )
121
 
122
  run_button.click(
123
  fn=generate,
124
- inputs=[prompt],
125
  outputs=[result_image],
 
126
  )
127
 
128
- demo.launch()
 
 
1
  import os
 
2
  import uuid
3
+ import random
4
+ from typing import Tuple
5
+
6
  import gradio as gr
7
  import numpy as np
8
  from PIL import Image
 
9
  import torch
10
+ import spaces
11
  from diffusers import StableDiffusionXLPipeline, EulerAncestralDiscreteScheduler
 
12
 
13
+
14
# -----------------------
# Style handling
# -----------------------
def apply_style(style_name: str, positive: str, negative: str = "") -> Tuple[str, str]:
    """Expand a named style template around the user's prompt.

    Returns a ``(positive_prompt, negative_prompt)`` pair. Unknown style
    names fall back to the default "3840 x 2160" style. The caller's
    *negative* text, if any, is appended after the style's own negatives.
    """
    default_name = "3840 x 2160"
    styles = {
        "3840 x 2160": (
            "hyper-realistic image of {prompt}. lifelike, authentic, natural colors, "
            "true-to-life details, landscape image, realistic lighting, immersive, highly detailed",
            "unrealistic, low resolution, artificial, over-saturated, distorted, fake",
        ),
        "Style Zero": ("{prompt}", ""),
    }
    template_pos, template_neg = styles.get(style_name, styles[default_name])
    combined_neg = f"{template_neg} {negative}" if negative else template_neg
    return template_pos.replace("{prompt}", positive), combined_neg.strip()
29
+
30
 
31
# -----------------------
# Model loader
# -----------------------
def load_and_prepare_model() -> StableDiffusionXLPipeline:
    """Build the RealVisXL SDXL pipeline once, tuned for the available device.

    Uses fp16 on CUDA and fp32 on CPU, swaps in the Euler-Ancestral
    scheduler, and opportunistically enables memory/throughput knobs.
    """
    repo = "SG161222/RealVisXL_V5.0_Lightning"
    cuda_ok = torch.cuda.is_available()

    pipe = StableDiffusionXLPipeline.from_pretrained(
        repo,
        torch_dtype=torch.float16 if cuda_ok else torch.float32,
        use_safetensors=True,
        add_watermarker=False,
    )

    # Euler-Ancestral is a stable, fast scheduler for Lightning checkpoints.
    pipe.scheduler = EulerAncestralDiscreteScheduler.from_config(pipe.scheduler.config)

    # xformers attention is optional — missing on CPU or some builds.
    try:
        pipe.enable_xformers_memory_efficient_attention()
    except Exception:
        pass

    # Safe perf knobs on CUDA.
    # NOTE(review): the flattened diff loses indentation — assumes
    # set_grad_enabled(False) was grouped under the CUDA branch; confirm.
    if cuda_ok:
        torch.backends.cuda.matmul.allow_tf32 = True
        torch.set_grad_enabled(False)

    return pipe.to(torch.device("cuda:0" if cuda_ok else "cpu"))
64
 
65
+
66
# Global model (loaded once per Space instance).
# NOTE(review): importing this module triggers the full model download/load —
# a heavy side effect at import time; intentional for HF Spaces warm start.
model = load_and_prepare_model()
68
 
69
+
70
# -----------------------
# Utils
# -----------------------
def randomize_seed_fn(seed: int, randomize_seed: bool) -> int:
    """Return a fresh random int32 seed when requested, else the given seed as int."""
    if not randomize_seed:
        return int(seed)
    # Draw from the full non-negative int32 range.
    return random.randint(0, np.iinfo(np.int32).max)
77
+
78
 
79
def save_image(img: Image.Image) -> str:
    """Persist *img* as a uniquely named PNG in the working dir and return the filename."""
    # Working-dir save lets the HF Space expose the file as an artifact.
    filename = f"{uuid.uuid4().hex}.png"
    img.save(filename)
    return filename
84
 
85
+
86
# -----------------------
# Generation
# -----------------------
@spaces.GPU(duration=60, enable_queue=True)
def generate(
    prompt: str,
    seed: int = 1,
    width: int = 1024,
    height: int = 1024,
    guidance_scale: float = 3.0,
    num_inference_steps: int = 25,
    randomize_seed: bool = False,
):
    """Run one SDXL text-to-image pass and return the saved image's file path.

    Raises gr.Error on an empty prompt; dimensions are snapped down to
    multiples of 8 (min 256) as SDXL expects.
    """
    global model

    # Guardrails: reject empty / whitespace-only prompts up front.
    if not (prompt and prompt.strip()):
        raise gr.Error("Please enter a prompt.")

    # SDXL prefers dims divisible by 8.
    width = int(max(256, width // 8 * 8))
    height = int(max(256, height // 8 * 8))

    seed = randomize_seed_fn(seed, randomize_seed)
    rng = torch.Generator(device=model.device).manual_seed(seed)

    positive_prompt, negative_prompt = apply_style("3840 x 2160", prompt)

    # NOTE: pass strings (not one-element lists) to keep batch size 1.
    result = model(
        prompt=positive_prompt,
        negative_prompt=negative_prompt,
        width=width,
        height=height,
        guidance_scale=float(guidance_scale),
        num_inference_steps=int(num_inference_steps),
        generator=rng,
        output_type="pil",
    )
    return save_image(result.images[0])
128
 
129
+
130
+ # -----------------------
131
+ # UI
132
+ # -----------------------
133
  with gr.Blocks(theme="soft") as demo:
 
134
  with gr.Row():
135
  with gr.Column(scale=12, elem_id="title_block"):
136
+ gr.Markdown(
137
+ "<h1 style='text-align:center; color:white; font-weight:bold; text-decoration:underline;'>SNAPSCRIBE</h1>"
138
+ )
139
+ gr.Markdown(
140
+ "<h2 style='text-align:center; color:white; font-weight:bold; text-decoration:underline;'>Developed using RealVisXL_V5.0_Lightning model with ❤ by Aklavya</h2>"
141
+ )
142
+
143
  with gr.Row():
144
  with gr.Column(scale=3):
145
  prompt = gr.Textbox(
 
147
  placeholder="Describe the image you want to create",
148
  lines=2,
149
  )
150
+ seed = gr.Number(value=1, label="Seed", precision=0)
151
+ randomize_seed = gr.Checkbox(value=True, label="Randomize Seed")
152
+ width = gr.Slider(512, 1536, value=1024, step=8, label="Width")
153
+ height = gr.Slider(512, 1536, value=1024, step=8, label="Height")
154
+ guidance_scale = gr.Slider(1.0, 10.0, value=3.0, step=0.5, label="Guidance Scale")
155
+ steps = gr.Slider(10, 35, value=25, step=1, label="Inference Steps")
156
+
157
+ run_button = gr.Button("Generate Image", variant="primary")
158
 
 
159
  example_prompts_text = (
160
  "Dew-covered spider web in morning sunlight, with blurred greenery\n"
161
  "--------------------------------------------\n"
 
166
  "Autumn forest with golden leaves, sunlight through trees, and a breeze"
167
  )
168
 
169
+ gr.Textbox(
170
  value=example_prompts_text,
171
+ lines=8,
172
  label="Sample Inputs",
173
  interactive=False,
174
  )
 
178
  label="Generated Image",
179
  type="filepath",
180
  elem_id="output_image",
 
181
  )
182
 
183
  run_button.click(
184
  fn=generate,
185
+ inputs=[prompt, seed, width, height, guidance_scale, steps, randomize_seed],
186
  outputs=[result_image],
187
+ api_name="generate",
188
  )
189
 
190
# Standard script guard: only launch the Gradio server when run directly.
if __name__ == "__main__":
    demo.launch()