Leteint commited on
Commit
8fc7d36
·
verified ·
1 Parent(s): 9f10e87

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +34 -192
app.py CHANGED
@@ -3,21 +3,14 @@ import json
3
  import os
4
  import torch
5
  import gradio as gr
6
-
7
  from diffusers import AutoPipelineForText2Image
8
  from PIL import PngImagePlugin
9
- from huggingface_hub import hf_hub_download
10
 
11
  BASE_MODEL_ID = "stabilityai/stable-diffusion-xl-base-1.0"
12
 
13
- # LoRA cuerpo hourglass (Muapi)
14
- LORA_HOURGLASS_REPO = "Muapi/hourglass-body-shape-sd1.5-sdxl-pony-flux-olaz"
15
- LORA_HOURGLASS_FILE = "hourglass-body-shape-sd1.5-sdxl-pony-flux-olaz.safetensors"
16
- LORA_HOURGLASS_ADAPTER = "hourglass_body"
17
-
18
- # LoRA visages / détails (Stable Yogi)
19
  LORA_FACE_REPO = "akash-guptag/Detailers_By_Stable_Yogi"
20
- LORA_FACE_ADAPTER = "face_helper"
21
 
22
  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
23
  DTYPE = torch.float16 if DEVICE == "cuda" else torch.float32
@@ -31,37 +24,11 @@ pipe = AutoPipelineForText2Image.from_pretrained(
31
  pipe.to(DEVICE)
32
  pipe.set_progress_bar_config(disable=True)
33
 
34
- # Téléchargement explicite du LoRA hourglass (debug + cache)
35
- try:
36
- local_path = hf_hub_download(
37
- repo_id=LORA_HOURGLASS_REPO,
38
- filename=LORA_HOURGLASS_FILE,
39
- local_dir="./lora_cache"
40
- )
41
- print(f"Fichier LoRA hourglass OK : {local_path}")
42
- except Exception as e:
43
- print(f"Erreur repo/fichier hourglass : {e}")
44
-
45
- print("Chargement LoRA hourglass body...")
46
- pipe.load_lora_weights(
47
- LORA_HOURGLASS_REPO,
48
- weight_name=LORA_HOURGLASS_FILE,
49
- adapter_name=LORA_HOURGLASS_ADAPTER,
50
- low_cpu_mem_usage=False, # important pour certains LoRA hybrides
51
- )
52
-
53
  print("Chargement LoRA face/detail (Stable Yogi)...")
54
- pipe.load_lora_weights(
55
- LORA_FACE_REPO,
56
- adapter_name=LORA_FACE_ADAPTER,
57
- )
58
 
59
  os.makedirs("outputs", exist_ok=True)
60
 
61
-
62
- # -----------------------
63
- # Fonction de génération
64
- # -----------------------
65
  @spaces.GPU()
66
  def generate(
67
  prompt: str,
@@ -73,32 +40,16 @@ def generate(
73
  height: float,
74
  face_enabled: bool,
75
  face_weight: float,
76
- hourglass_enabled: bool,
77
- hourglass_weight: float,
78
  script_name: str,
79
  ):
80
  if not prompt:
81
  return None, "Prompt vide.", ""
82
 
83
- # Seed
84
  seed_int = int(seed) if seed is not None else 0
85
- if seed_int < 0:
86
- generator = torch.Generator(device=DEVICE)
87
- else:
88
- generator = torch.Generator(device=DEVICE).manual_seed(seed_int)
89
-
90
- # Gestion des LoRA : liste d’adapters + poids
91
- adapters = []
92
- weights = []
93
-
94
- # LoRA visages / détails
95
- adapters.append(LORA_FACE_ADAPTER)
96
- weights.append(float(face_weight) if face_enabled else 0.0)
97
-
98
- # LoRA hourglass
99
- adapters.append(LORA_HOURGLASS_ADAPTER)
100
- weights.append(float(hourglass_weight) if hourglass_enabled else 0.0)
101
 
 
 
102
  pipe.set_adapters(adapters, adapter_weights=weights)
103
 
104
  try:
@@ -112,24 +63,16 @@ def generate(
112
  generator=generator,
113
  )
114
  except Exception as e:
115
- err_txt = f"Erreur pendant la génération SDXL : {repr(e)}"
116
- return None, err_txt, ""
117
 
118
  image = result.images[0]
119
-
120
  metadata = {
121
  "base_model": BASE_MODEL_ID,
122
  "face_lora_repo": LORA_FACE_REPO,
123
- "face_lora_adapter": LORA_FACE_ADAPTER,
124
  "face_enabled": bool(face_enabled),
125
  "face_weight": float(face_weight) if face_enabled else 0.0,
126
- "hourglass_lora_repo": LORA_HOURGLASS_REPO,
127
- "hourglass_lora_file": LORA_HOURGLASS_FILE,
128
- "hourglass_lora_adapter": LORA_HOURGLASS_ADAPTER,
129
- "hourglass_enabled": bool(hourglass_enabled),
130
- "hourglass_weight": float(hourglass_weight) if hourglass_enabled else 0.0,
131
  "prompt": prompt,
132
- "negative_prompt": negative_prompt,
133
  "seed": seed_int,
134
  "steps": int(steps),
135
  "guidance_scale": float(guidance),
@@ -137,10 +80,9 @@ def generate(
137
  "height": int(height),
138
  }
139
 
140
- base_name = script_name.strip() or "sdxl_lora_generation"
141
- safe_name = base_name.replace(" ", "_")
142
- img_path = os.path.join("outputs", f"{safe_name}.png")
143
- json_path = os.path.join("outputs", f"{safe_name}.json")
144
 
145
  pnginfo = PngImagePlugin.PngInfo()
146
  pnginfo.add_text("generation_params", json.dumps(metadata, ensure_ascii=False))
@@ -149,133 +91,33 @@ def generate(
149
  with open(json_path, "w", encoding="utf-8") as f:
150
  json.dump(metadata, f, ensure_ascii=False, indent=2)
151
 
152
- script_txt = json.dumps(metadata, ensure_ascii=False, indent=2)
153
-
154
- return image, script_txt, json_path
155
-
156
-
157
- # -----------------------
158
- # UI Gradio
159
- # -----------------------
160
- with gr.Blocks(title="SDXL + LoRA (Diffusers)") as demo:
161
- gr.Markdown(
162
- "## SDXL base + LoRA SDXL (Diffusers)\n"
163
- "Génération SDXL 1.0 avec LoRA detailer + hourglass body."
164
- )
165
 
 
 
 
 
166
  with gr.Row():
167
  with gr.Column():
168
- prompt = gr.Textbox(
169
- label="Prompt",
170
- lines=4,
171
- placeholder="Décris la scène à générer avec SDXL...",
172
- value=(
173
- "masterpiece, best quality, 1girl, looking at viewer, "
174
- "sharp face, detailed skin, realistic eyes, hourglass figure, slim waist, wide hips"
175
- ),
176
- )
177
- negative = gr.Textbox(
178
- label="Negative prompt",
179
- lines=3,
180
- value="blurry face, deformed face, bad anatomy, bad proportions, low quality",
181
- )
182
-
183
- seed = gr.Number(
184
- label="Seed (-1 = random)",
185
- value=-1,
186
- )
187
- steps = gr.Slider(
188
- minimum=10,
189
- maximum=60,
190
- value=30,
191
- step=1,
192
- label="Steps",
193
- )
194
- guidance = gr.Slider(
195
- minimum=1.0,
196
- maximum=20.0,
197
- value=7.0,
198
- step=0.5,
199
- label="CFG / Guidance scale",
200
- )
201
-
202
- width = gr.Slider(
203
- minimum=512,
204
- maximum=1536,
205
- value=1024,
206
- step=64,
207
- label="Width (SDXL natif 1024)",
208
- )
209
- height = gr.Slider(
210
- minimum=512,
211
- maximum=1536,
212
- value=1024,
213
- step=64,
214
- label="Height (SDXL natif 1024)",
215
- )
216
-
217
- face_enabled = gr.Checkbox(
218
- label="Activer LoRA face/detail (Stable Yogi)",
219
- value=True,
220
- )
221
- face_weight = gr.Slider(
222
- minimum=0.0,
223
- maximum=1.5,
224
- value=0.8, # recommandé <= 0.8 pour ce LoRA[web:60][web:63]
225
- step=0.05,
226
- label="Force LoRA face/detail",
227
- )
228
-
229
- hourglass_enabled = gr.Checkbox(
230
- label="Activer LoRA hourglass body shape",
231
- value=True,
232
- )
233
-
234
- hourglass_weight = gr.Slider(
235
- minimum=0.0,
236
- maximum=2.0,
237
- value=0.9,
238
- step=0.05,
239
- label="Force LoRA hourglass body",
240
- )
241
-
242
- script_name = gr.Textbox(
243
- label="Nom base pour l'image / script",
244
- value="sdxl_lora_generation",
245
- )
246
-
247
- run_btn = gr.Button("Générer", variant="primary")
248
-
249
- with gr.Column():
250
- out_img = gr.Image(
251
- label="Image générée (SDXL)",
252
- )
253
- out_script = gr.Textbox(
254
- label="Script (JSON des paramètres)",
255
- lines=20,
256
  )
257
- out_file = gr.File(
258
- label="Fichier JSON des paramètres (téléchargeable)",
259
- )
260
-
261
- run_btn.click(
262
- fn=generate,
263
- inputs=[
264
- prompt,
265
- negative,
266
- seed,
267
- steps,
268
- guidance,
269
- width,
270
- height,
271
- face_enabled,
272
- face_weight,
273
- hourglass_enabled,
274
- hourglass_weight,
275
- script_name,
276
- ],
277
- outputs=[out_img, out_script, out_file],
278
- )
279
 
280
  if __name__ == "__main__":
281
  demo.launch()
 
3
  import os
4
  import torch
5
  import gradio as gr
 
6
  from diffusers import AutoPipelineForText2Image
7
  from PIL import PngImagePlugin
 
8
 
9
  BASE_MODEL_ID = "stabilityai/stable-diffusion-xl-base-1.0"
10
 
11
+ # SEULEMENT LoRA face/detail compatible
 
 
 
 
 
12
  LORA_FACE_REPO = "akash-guptag/Detailers_By_Stable_Yogi"
13
+ LORA_FACE_ADAPTER = "face_detail"
14
 
15
  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
16
  DTYPE = torch.float16 if DEVICE == "cuda" else torch.float32
 
24
  pipe.to(DEVICE)
25
  pipe.set_progress_bar_config(disable=True)
26
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
27
  print("Chargement LoRA face/detail (Stable Yogi)...")
28
+ pipe.load_lora_weights(LORA_FACE_REPO, adapter_name=LORA_FACE_ADAPTER)
 
 
 
29
 
30
  os.makedirs("outputs", exist_ok=True)
31
 
 
 
 
 
32
  @spaces.GPU()
33
  def generate(
34
  prompt: str,
 
40
  height: float,
41
  face_enabled: bool,
42
  face_weight: float,
 
 
43
  script_name: str,
44
  ):
45
  if not prompt:
46
  return None, "Prompt vide.", ""
47
 
 
48
  seed_int = int(seed) if seed is not None else 0
49
+ generator = torch.Generator(device=DEVICE).manual_seed(seed_int) if seed_int >= 0 else torch.Generator(device=DEVICE)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
50
 
51
+ adapters = [LORA_FACE_ADAPTER]
52
+ weights = [float(face_weight) if face_enabled else 0.0]
53
  pipe.set_adapters(adapters, adapter_weights=weights)
54
 
55
  try:
 
63
  generator=generator,
64
  )
65
  except Exception as e:
66
+ return None, f"Erreur SDXL : {repr(e)}", ""
 
67
 
68
  image = result.images[0]
69
+
70
  metadata = {
71
  "base_model": BASE_MODEL_ID,
72
  "face_lora_repo": LORA_FACE_REPO,
 
73
  "face_enabled": bool(face_enabled),
74
  "face_weight": float(face_weight) if face_enabled else 0.0,
 
 
 
 
 
75
  "prompt": prompt,
 
76
  "seed": seed_int,
77
  "steps": int(steps),
78
  "guidance_scale": float(guidance),
 
80
  "height": int(height),
81
  }
82
 
83
+ base_name = script_name.strip().replace(" ", "_") or "sdxl_detailer"
84
+ img_path = os.path.join("outputs", f"{base_name}.png")
85
+ json_path = os.path.join("outputs", f"{base_name}.json")
 
86
 
87
  pnginfo = PngImagePlugin.PngInfo()
88
  pnginfo.add_text("generation_params", json.dumps(metadata, ensure_ascii=False))
 
91
  with open(json_path, "w", encoding="utf-8") as f:
92
  json.dump(metadata, f, ensure_ascii=False, indent=2)
93
 
94
+ return image, json.dumps(metadata, ensure_ascii=False, indent=2), json_path
 
 
 
 
 
 
 
 
 
 
 
 
95
 
96
+ # UI simplifiée
97
+ with gr.Blocks(title="SDXL + Face Detail LoRA") as demo:
98
+ gr.Markdown("## SDXL 1.0 + LoRA Detailer (Stable Yogi)")
99
+
100
  with gr.Row():
101
  with gr.Column():
102
+ prompt = gr.Textbox(label="Prompt", lines=4, value="masterpiece, 1girl, detailed face, sharp eyes, realistic skin")
103
+ negative = gr.Textbox(label="Negative", value="blurry, deformed, low quality")
104
+
105
+ seed = gr.Number(label="Seed (-1=random)", value=-1)
106
+ steps = gr.Slider(10, 60, 30, step=1, label="Steps")
107
+ guidance = gr.Slider(1.0, 20.0, 7.0, step=0.5, label="Guidance")
108
+
109
+ width = gr.Slider(512, 1536, 1024, step=64, label="Width")
110
+ height = gr.Slider(512, 1536, 1024, step=64, label="Height")
111
+
112
+ face_enabled = gr.Checkbox("Activer LoRA face/detail", value=True)
113
+ face_weight = gr.Slider(0.0, 1.5, 0.8, step=0.05, label="Poids LoRA")
114
+
115
+ script_name = gr.Textbox("Nom sortie", value="test_detailer")
116
+ gr.Button("Générer", variant="primary").click(
117
+ generate,
118
+ inputs=[prompt, negative, seed, steps, guidance, width, height, face_enabled, face_weight, script_name],
119
+ outputs=[gr.Image(label="Image"), gr.Textbox(label="JSON", lines=15), gr.File(label="Télécharger JSON")]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
120
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
121
 
122
  if __name__ == "__main__":
123
  demo.launch()