Aklavya commited on
Commit
8a40f1d
·
verified ·
1 Parent(s): c24d164

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +55 -89
app.py CHANGED
@@ -4,15 +4,11 @@ import uuid
4
  import gradio as gr
5
  import numpy as np
6
  from PIL import Image
 
7
  import torch
8
- from diffusers import StableDiffusionXLPipeline, EulerAncestralDiscreteScheduler, FluxPipeline
9
- from huggingface_hub import snapshot_download
10
  from typing import Tuple
11
 
12
- # Ensure Hugging Face token from secrets
13
- HF_TOKEN = os.getenv("HF_TOKEN")
14
-
15
- # Function to apply the style based on the selected model
16
  def apply_style(style_name: str, positive: str, negative: str = "") -> Tuple[str, str]:
17
  styles = {
18
  "3840 x 2160": (
@@ -26,129 +22,99 @@ def apply_style(style_name: str, positive: str, negative: str = "") -> Tuple[str
26
  p, n = styles.get(style_name, styles[DEFAULT_STYLE_NAME])
27
  return p.replace("{prompt}", positive), n + negative
28
 
29
- # Function to load and prepare the model
30
- def load_and_prepare_model(model_name: str):
31
- if model_name == "RealVisXL_V5.0_Lightning":
32
- model_id = "SG161222/RealVisXL_V5.0_Lightning"
33
- # Ensure the model is downloaded locally
34
- local_model_path = snapshot_download(model_id, token=HF_TOKEN)
35
- pipe = StableDiffusionXLPipeline.from_pretrained(
36
- local_model_path,
37
- torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
38
- use_safetensors=True,
39
- add_watermarker=False,
40
- ).to(torch.device("cuda:0" if torch.cuda.is_available() else "cpu"))
41
- pipe.scheduler = EulerAncestralDiscreteScheduler.from_config(pipe.scheduler.config)
42
- elif model_name == "FLUX.1-dev":
43
- model_id = "black-forest-labs/FLUX.1-dev"
44
- # Ensure the model is downloaded locally
45
- local_model_path = snapshot_download(model_id, token=HF_TOKEN)
46
- pipe = FluxPipeline.from_pretrained(
47
- local_model_path,
48
- torch_dtype=torch.bfloat16,
49
- )
50
- pipe.enable_model_cpu_offload() # Save VRAM by offloading model to CPU
51
- else:
52
- raise ValueError("Unsupported model")
53
-
54
  return pipe
55
 
56
- # Function to save image
 
 
 
 
 
 
57
  def save_image(img):
58
  unique_name = str(uuid.uuid4()) + ".png"
59
  img.save(unique_name)
60
  return unique_name
61
 
62
- # Main image generation function
63
  def generate(
64
  prompt: str,
65
- model_name: str,
66
  seed: int = 1,
67
  width: int = 1024,
68
  height: int = 1024,
69
  guidance_scale: float = 3,
70
- num_inference_steps: int = 50,
71
  randomize_seed: bool = False,
72
  ):
73
- model = load_and_prepare_model(model_name)
74
- seed = random.randint(0, np.iinfo(np.int32).max) if randomize_seed else seed
75
- generator = torch.Generator("cpu" if model_name == "FLUX.1-dev" else model.device).manual_seed(seed)
76
 
77
- if model_name == "RealVisXL_V5.0_Lightning":
78
- positive_prompt, negative_prompt = apply_style("3840 x 2160", prompt)
79
- options = {
80
- "prompt": [positive_prompt],
81
- "negative_prompt": [negative_prompt],
82
- "width": width,
83
- "height": height,
84
- "guidance_scale": guidance_scale,
85
- "num_inference_steps": num_inference_steps,
86
- "generator": generator,
87
- "output_type": "pil",
88
- }
89
- images = model(**options).images
90
- elif model_name == "FLUX.1-dev":
91
- image = model(
92
- prompt=prompt,
93
- height=height,
94
- width=width,
95
- guidance_scale=guidance_scale,
96
- num_inference_steps=num_inference_steps,
97
- max_sequence_length=512,
98
- generator=generator
99
- ).images[0]
100
- images = [image]
101
 
102
- image_paths = [save_image(img) for img in images]
103
- return image_paths, seed
 
 
 
 
 
 
 
 
 
 
 
 
 
 
104
 
105
- # Gradio interface setup
106
  with gr.Blocks(theme="soft") as demo:
107
  # Centered text "SNAPSCRIBE" at the top of the screen
108
  gr.Markdown("<h1 style='text-align:center; color:white; font-weight:bold; text-decoration:underline;'>SNAPSCRIBE</h1>")
109
 
110
- # Dropdown for model selection
111
  with gr.Row():
112
  with gr.Column(scale=3):
113
- model_dropdown = gr.Dropdown(
114
- choices=["RealVisXL_V5.0_Lightning", "FLUX.1-dev"],
115
- label="Select Model",
116
- value="RealVisXL_V5.0_Lightning"
117
- )
118
  prompt = gr.Textbox(
119
  label="Input Prompt",
120
  placeholder="Describe the image you want to create",
121
  lines=2,
122
  )
123
  run_button = gr.Button("Generate Image")
 
124
  with gr.Column(scale=7):
125
- result = gr.Gallery(label="Generated Image", columns=2)
126
 
127
  run_button.click(
128
  fn=generate,
129
- inputs=[prompt, model_dropdown],
130
- outputs=[result],
131
  )
132
 
133
- # Footer added to center-align the text
134
- gr.HTML("""
135
  <style>
136
- .footer {
137
- position: relative;
138
- left: 0;
139
- bottom: 0;
140
- width: 100%;
141
- background-color: white;
142
- color: black;
143
- text-align: center;
144
- padding: 10px;
145
- margin-top: 20px;
146
- }
147
  </style>
148
  <div class="footer">
149
- <p>Developed with ❤ by Aklavya(Bucky)</p>
150
  </div>
151
  """)
152
 
153
- # Launch the Gradio interface
154
  demo.launch()
 
4
  import gradio as gr
5
  import numpy as np
6
  from PIL import Image
7
+ import spaces
8
  import torch
9
+ from diffusers import StableDiffusionXLPipeline, EulerAncestralDiscreteScheduler
 
10
  from typing import Tuple
11
 
 
 
 
 
12
  def apply_style(style_name: str, positive: str, negative: str = "") -> Tuple[str, str]:
13
  styles = {
14
  "3840 x 2160": (
 
22
  p, n = styles.get(style_name, styles[DEFAULT_STYLE_NAME])
23
  return p.replace("{prompt}", positive), n + negative
24
 
25
def load_and_prepare_model():
    """Build the RealVisXL V5.0 Lightning SDXL pipeline, ready for inference.

    Loads the checkpoint with safetensors and no watermarker, picks fp16 on
    CUDA (fp32 on CPU), moves the pipeline to the chosen device, and swaps
    in the Euler-Ancestral scheduler.

    Returns:
        StableDiffusionXLPipeline: the configured pipeline.
    """
    model_id = "SG161222/RealVisXL_V5.0_Lightning"
    cuda_ok = torch.cuda.is_available()
    device = torch.device("cuda:0" if cuda_ok else "cpu")
    pipe = StableDiffusionXLPipeline.from_pretrained(
        model_id,
        torch_dtype=torch.float16 if cuda_ok else torch.float32,
        use_safetensors=True,
        add_watermarker=False,
    ).to(device)
    # Replace the default scheduler with Euler-Ancestral, configured from the
    # pipeline's own scheduler config.
    pipe.scheduler = EulerAncestralDiscreteScheduler.from_config(pipe.scheduler.config)
    return pipe

# Load once at module import so every request reuses the same pipeline.
model = load_and_prepare_model()
37
+
38
def randomize_seed_fn(seed: int, randomize_seed: bool) -> int:
    """Return a fresh random int32-range seed when randomization is on,
    otherwise pass the caller's seed through unchanged."""
    if not randomize_seed:
        return seed
    return random.randint(0, np.iinfo(np.int32).max)
42
+
43
def save_image(img):
    """Persist an image under a collision-free random filename.

    Args:
        img: an object exposing ``save(path)`` (a PIL image in practice —
            TODO confirm; only ``save`` is relied upon here).

    Returns:
        str: the generated ``<uuid4>.png`` filename, usable as a filepath.
    """
    filename = f"{uuid.uuid4()}.png"
    img.save(filename)
    return filename
47
 
48
@spaces.GPU(duration=60, enable_queue=True)
def generate(
    prompt: str,
    seed: int = 1,
    width: int = 1024,
    height: int = 1024,
    guidance_scale: float = 3,
    num_inference_steps: int = 25,
    randomize_seed: bool = False,
):
    """Run one text-to-image generation and return the saved image's path.

    Args:
        prompt: user text describing the desired image.
        seed: RNG seed; replaced by a random one when ``randomize_seed``.
        width / height: output resolution in pixels.
        guidance_scale: classifier-free guidance strength.
        num_inference_steps: denoising steps for the sampler.
        randomize_seed: when True, ignore ``seed`` and pick a random one.

    Returns:
        str: filepath of the saved PNG (first image of the batch).
    """
    global model

    # Resolve the effective seed, then build a generator on the pipeline's
    # device so results are reproducible for a fixed seed.
    seed = int(randomize_seed_fn(seed, randomize_seed))
    generator = torch.Generator(device=model.device).manual_seed(seed)

    # Expand the raw prompt with the fixed "3840 x 2160" style template.
    positive_prompt, negative_prompt = apply_style("3840 x 2160", prompt)

    result = model(
        prompt=[positive_prompt],
        negative_prompt=[negative_prompt],
        width=width,
        height=height,
        guidance_scale=guidance_scale,
        num_inference_steps=num_inference_steps,
        generator=generator,
        output_type="pil",
    )
    # Saving the first generated image
    return save_image(result.images[0])
79
 
 
80
# --- Gradio UI --------------------------------------------------------------
with gr.Blocks(theme="soft") as demo:
    # Centered text "SNAPSCRIBE" at the top of the screen
    gr.Markdown("<h1 style='text-align:center; color:white; font-weight:bold; text-decoration:underline;'>SNAPSCRIBE</h1>")

    with gr.Row():
        # Left column: prompt input and trigger button.
        with gr.Column(scale=3):
            prompt = gr.Textbox(
                label="Input Prompt",
                placeholder="Describe the image you want to create",
                lines=2,
            )
            run_button = gr.Button("Generate Image")
            gr.Markdown("Developed using the RealVisXL_V5.0_Lightning model.", elem_id="model_info")
        # Right column: the generated image, displayed from its saved path.
        with gr.Column(scale=7):
            result_image = gr.Image(label="Generated Image", type="filepath")

    # Wire the button: prompt in, saved-image filepath out.
    run_button.click(
        fn=generate,
        inputs=[prompt],
        outputs=[result_image],
    )

    # Footer with custom style and text
    gr.Markdown("""
    <style>
        .footer {
            position: relative;
            left: 0;
            bottom: 0;
            width: 100%;
            background-color: white;
            color: black;
            text-align: center;
        }
    </style>
    <div class="footer">
        <p>Developed with ❤ by Aklavya(Bucky)</p>
    </div>
    """)

demo.launch()