Aklavya committed on
Commit
35a6af1
·
verified ·
1 Parent(s): 6fae227

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +66 -49
app.py CHANGED
@@ -1,19 +1,22 @@
1
  import os
2
  import uuid
3
  import random
4
- from typing import Tuple
5
 
6
  import gradio as gr
7
  import numpy as np
8
  from PIL import Image
9
  import torch
10
  import spaces
11
- from diffusers import StableDiffusionXLPipeline, EulerAncestralDiscreteScheduler
 
 
 
 
12
 
 
 
13
 
14
- # -----------------------
15
- # Style handling
16
- # -----------------------
17
  def apply_style(style_name: str, positive: str, negative: str = "") -> Tuple[str, str]:
18
  styles = {
19
  "3840 x 2160": (
@@ -25,67 +28,85 @@ def apply_style(style_name: str, positive: str, negative: str = "") -> Tuple[str
25
  }
26
  DEFAULT_STYLE_NAME = "3840 x 2160"
27
  p, n = styles.get(style_name, styles[DEFAULT_STYLE_NAME])
28
- return p.replace("{prompt}", positive), (n + (" " + negative if negative else "")).strip()
 
29
 
 
 
 
 
30
 
31
- # -----------------------
32
- # Model loader
33
- # -----------------------
34
- def load_and_prepare_model() -> StableDiffusionXLPipeline:
35
- model_id = "SG161222/RealVisXL_V5.0_Lightning"
36
  use_cuda = torch.cuda.is_available()
37
  dtype = torch.float16 if use_cuda else torch.float32
38
- device = torch.device("cuda:0" if use_cuda else "cpu")
39
-
40
  pipe = StableDiffusionXLPipeline.from_pretrained(
41
- model_id,
42
  torch_dtype=dtype,
43
  use_safetensors=True,
44
  add_watermarker=False,
 
45
  )
46
-
47
- # Use a stable, fast scheduler
48
  pipe.scheduler = EulerAncestralDiscreteScheduler.from_config(pipe.scheduler.config)
49
-
50
- # Memory/throughput optimizations
51
  try:
52
  pipe.enable_xformers_memory_efficient_attention()
53
  except Exception:
54
- # xformers may be missing on CPU or some build types — that's fine
55
  pass
 
 
 
56
 
57
- # Optional, safe perf knobs on CUDA
58
- if use_cuda:
59
- torch.backends.cuda.matmul.allow_tf32 = True
60
- torch.set_grad_enabled(False)
61
-
 
 
 
 
 
 
 
 
 
62
  pipe = pipe.to(device)
63
  return pipe
64
 
 
 
 
65
 
66
- # Global model (loaded once per Space instance)
67
- model = load_and_prepare_model()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
68
 
 
 
69
 
70
- # -----------------------
71
- # Utils
72
- # -----------------------
73
  def randomize_seed_fn(seed: int, randomize_seed: bool) -> int:
74
  if randomize_seed:
75
  seed = random.randint(0, np.iinfo(np.int32).max)
76
  return int(seed)
77
 
78
-
79
  def save_image(img: Image.Image) -> str:
80
- # Save to the working dir so HF can expose it as an artifact
81
  unique_name = f"{uuid.uuid4().hex}.png"
82
  img.save(unique_name)
83
  return unique_name
84
 
85
-
86
- # -----------------------
87
- # Generation
88
- # -----------------------
89
  @spaces.GPU(duration=60, enable_queue=True)
90
  def generate(
91
  prompt: str,
@@ -96,40 +117,36 @@ def generate(
96
  num_inference_steps: int = 25,
97
  randomize_seed: bool = False,
98
  ):
99
- global model
100
-
101
- # Guardrails
102
  if not prompt or not prompt.strip():
103
  raise gr.Error("Please enter a prompt.")
104
 
105
- # SDXL prefers dims divisible by 8
106
  width = int(max(256, (width // 8) * 8))
107
  height = int(max(256, (height // 8) * 8))
108
 
109
  seed = randomize_seed_fn(seed, randomize_seed)
110
  generator = torch.Generator(device=model.device).manual_seed(seed)
111
 
 
112
  positive_prompt, negative_prompt = apply_style("3840 x 2160", prompt)
113
 
114
- # NOTE: pass strings (not one-element lists)
115
- images = model(
116
  prompt=positive_prompt,
117
  negative_prompt=negative_prompt,
118
- width=width,
119
- height=height,
120
  guidance_scale=float(guidance_scale),
121
  num_inference_steps=int(num_inference_steps),
122
  generator=generator,
123
  output_type="pil",
124
- ).images
125
 
 
 
126
  image_path = save_image(images[0])
127
  return image_path
128
 
129
-
130
- # -----------------------
131
- # UI
132
- # -----------------------
133
  with gr.Blocks(theme="soft") as demo:
134
  with gr.Row():
135
  with gr.Column(scale=12, elem_id="title_block"):
@@ -137,7 +154,7 @@ with gr.Blocks(theme="soft") as demo:
137
  "<h1 style='text-align:center; color:white; font-weight:bold; text-decoration:underline;'>SNAPSCRIBE</h1>"
138
  )
139
  gr.Markdown(
140
- "<h2 style='text-align:center; color:white; font-weight:bold; text-decoration:underline;'>Developed using RealVisXL_V5.0_Lightning model with ❤ by Aklavya</h2>"
141
  )
142
 
143
  with gr.Row():
 
1
  import os
2
  import uuid
3
  import random
4
+ from typing import Tuple, Optional
5
 
6
  import gradio as gr
7
  import numpy as np
8
  from PIL import Image
9
  import torch
10
  import spaces
11
+ from diffusers import (
12
+ StableDiffusionXLPipeline,
13
+ StableDiffusionPipeline,
14
+ EulerAncestralDiscreteScheduler,
15
+ )
16
 
17
# Gated Hub repo: requires an approved access request plus HF_TOKEN at load time.
PRIMARY_MODEL_ID = "SG161222/RealVisXL_V5.0_Lightning"
# Public fallback; sd-turbo is a distilled turbo variant of Stable Diffusion 2.1.
FALLBACK_MODEL_ID = "stabilityai/sd-turbo"
19
 
 
 
 
20
  def apply_style(style_name: str, positive: str, negative: str = "") -> Tuple[str, str]:
21
  styles = {
22
  "3840 x 2160": (
 
28
  }
29
  DEFAULT_STYLE_NAME = "3840 x 2160"
30
  p, n = styles.get(style_name, styles[DEFAULT_STYLE_NAME])
31
+ n2 = (n + (" " + negative if negative else "")).strip()
32
+ return p.replace("{prompt}", positive), n2
33
 
34
+ def _enable_performance_knobs():
35
+ if torch.cuda.is_available():
36
+ torch.backends.cuda.matmul.allow_tf32 = True
37
+ torch.set_grad_enabled(False)
38
 
39
def _try_load_realvis(hf_token: Optional[str]):
    """Build the gated RealVisXL (SDXL) pipeline.

    Propagates any exception from `from_pretrained` (notably auth failures
    for the gated repo) so the caller can decide whether to fall back.
    """
    on_gpu = torch.cuda.is_available()
    pipe = StableDiffusionXLPipeline.from_pretrained(
        PRIMARY_MODEL_ID,
        # fp16 only makes sense on CUDA; CPU inference stays in fp32.
        torch_dtype=torch.float16 if on_gpu else torch.float32,
        use_safetensors=True,
        add_watermarker=False,
        token=hf_token,  # gated repo: the token is required for download
    )
    # Fast ancestral sampler, rebuilt from the checkpoint's own config.
    pipe.scheduler = EulerAncestralDiscreteScheduler.from_config(pipe.scheduler.config)
    try:
        pipe.enable_xformers_memory_efficient_attention()
    except Exception:
        # Best effort: xformers may be absent (CPU builds) — not fatal.
        pass
    return pipe.to(torch.device("cuda:0" if on_gpu else "cpu"))
57
 
58
def _try_load_fallback():
    """Build the public fallback pipeline.

    sd-turbo is a distilled turbo variant of Stable Diffusion 2.1 —
    small, fast, and ungated, so the app keeps working without a token.
    """
    on_gpu = torch.cuda.is_available()
    pipe = StableDiffusionPipeline.from_pretrained(
        FALLBACK_MODEL_ID,
        torch_dtype=torch.float16 if on_gpu else torch.float32,
        use_safetensors=True,
    )
    try:
        pipe.enable_xformers_memory_efficient_attention()
    except Exception:
        # xformers is optional; continue without memory-efficient attention.
        pass
    return pipe.to(torch.device("cuda:0" if on_gpu else "cpu"))
74
 
75
def load_and_prepare_model():
    """Load the primary RealVisXL pipeline, falling back to sd-turbo on failure.

    Auth-looking failures (401/403) get a targeted hint in the server logs;
    any other failure is logged verbatim. Either way the app still starts,
    backed by the public fallback model.
    """
    _enable_performance_knobs()
    # Empty/whitespace-only env values are treated as "no token".
    hf_token = os.getenv("HF_TOKEN", "").strip() or None

    try:
        return _try_load_realvis(hf_token)
    except Exception as exc:
        lowered = str(exc).lower()
        auth_markers = ("401", "403", "unauthorized", "forbidden")
        if any(marker in lowered for marker in auth_markers):
            # Clear hint in server logs; UI will still work via fallback.
            print(
                "\n[WARNING] Could not load RealVisXL (auth). "
                "Make sure you've requested access and set HF_TOKEN in Space secrets.\n"
            )
        else:
            print(f"\n[WARNING] RealVisXL failed to load: {exc}\n")

    print("[INFO] Falling back to stabilityai/sd-turbo (public).")
    return _try_load_fallback()
96
 
97
# Load the pipeline once at import time; every request reuses this instance.
model = load_and_prepare_model()
99
 
 
 
 
100
def randomize_seed_fn(seed: int, randomize_seed: bool) -> int:
    """Return *seed* as an int, or a fresh random 32-bit seed when requested."""
    if not randomize_seed:
        return int(seed)
    # Upper bound matches torch/diffusers' conventional int32 seed range.
    return random.randint(0, np.iinfo(np.int32).max)
104
 
 
105
def save_image(img: Image.Image) -> str:
    """Write *img* to the working directory under a collision-free PNG name.

    Returns the relative file path (the working dir is where HF Spaces
    can expose generated files as artifacts).
    """
    target = f"{uuid.uuid4().hex}.png"
    img.save(target)
    return target
109
 
 
 
 
 
110
  @spaces.GPU(duration=60, enable_queue=True)
111
  def generate(
112
  prompt: str,
 
117
  num_inference_steps: int = 25,
118
  randomize_seed: bool = False,
119
  ):
 
 
 
120
  if not prompt or not prompt.strip():
121
  raise gr.Error("Please enter a prompt.")
122
 
123
+ # Make dimensions friendly for SD models
124
  width = int(max(256, (width // 8) * 8))
125
  height = int(max(256, (height // 8) * 8))
126
 
127
  seed = randomize_seed_fn(seed, randomize_seed)
128
  generator = torch.Generator(device=model.device).manual_seed(seed)
129
 
130
+ # If model is SDXL pipeline, use the styled prompts; if fallback SD1.5 turbo, style still OK
131
  positive_prompt, negative_prompt = apply_style("3840 x 2160", prompt)
132
 
133
+ # Some pipelines (sd-turbo) ignore guidance/steps or behave differently; passing is still safe
134
+ out = model(
135
  prompt=positive_prompt,
136
  negative_prompt=negative_prompt,
137
+ width=width if "xl" in model.__class__.__name__.lower() else None,
138
+ height=height if "xl" in model.__class__.__name__.lower() else None,
139
  guidance_scale=float(guidance_scale),
140
  num_inference_steps=int(num_inference_steps),
141
  generator=generator,
142
  output_type="pil",
143
+ )
144
 
145
+ # Handle both diffusers return shapes
146
+ images = getattr(out, "images", out)
147
  image_path = save_image(images[0])
148
  return image_path
149
 
 
 
 
 
150
  with gr.Blocks(theme="soft") as demo:
151
  with gr.Row():
152
  with gr.Column(scale=12, elem_id="title_block"):
 
154
  "<h1 style='text-align:center; color:white; font-weight:bold; text-decoration:underline;'>SNAPSCRIBE</h1>"
155
  )
156
  gr.Markdown(
157
+ "<h2 style='text-align:center; color:white; font-weight:bold; text-decoration:underline;'>Developed with ❤ by Aklavya</h2>"
158
  )
159
 
160
  with gr.Row():