Leteint commited on
Commit
b02b56c
·
verified ·
1 Parent(s): 2c061b4

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +263 -4
app.py CHANGED
@@ -1,7 +1,266 @@
 
 
 
 
 
1
  import gradio as gr
2
 
3
- def greet(name):
4
- return "Hello " + name + "!!"
5
 
6
- demo = gr.Interface(fn=greet, inputs="text", outputs="text")
7
- demo.launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import requests
2
+ import base64
3
+ import io
4
+ import json
5
+ from PIL import Image, PngImagePlugin
6
  import gradio as gr
7
 
8
+ # URL de ton serveur AUTOMATIC1111 (à adapter)
9
+ A1111_URL = "http://127.0.0.1:7860"
10
 
11
+ # Nom du LoRA tel qu’il apparaît dans A1111 (onglet Extra Networks > Lora)
12
+ # ex : "HighSchoolDxD_RiasGremoryXL" si c’est ainsi qu’il est nommé
13
+ RIAS_LORA_NAME = "HighSchoolDxD_RiasGremory"
14
+
15
+ # Triggers Rias
16
+ RIAS_TRIGGERS = (
17
+ "RiasGremory, red hair, long hair, blue eyes, ahoge, "
18
+ "school uniform, striped shirt, neck ribbon, black ribbon, "
19
+ "black corset, purple skirt, huge breasts"
20
+ )
21
+
22
+
23
+ def build_prompt_with_lora(user_prompt: str,
24
+ enable_rias: bool,
25
+ lora_weight: float) -> str:
26
+ """
27
+ Construit le prompt final en ajoutant éventuellement le LoRA Rias
28
+ + ses trigger words.
29
+ """
30
+
31
+ user_prompt = (user_prompt or "").strip()
32
+
33
+ if not enable_rias:
34
+ return user_prompt
35
+
36
+ # Syntaxe A1111 pour LoRA : <lora:Nom:poids>
37
+ lora_tag = f"<lora:{RIAS_LORA_NAME}:{lora_weight:.2f}>"
38
+ parts = [lora_tag, RIAS_TRIGGERS, user_prompt]
39
+ return ", ".join([p for p in parts if p])
40
+
41
+
42
+ def txt2img_sdxl(
43
+ prompt: str,
44
+ negative_prompt: str,
45
+ seed: float,
46
+ steps: float,
47
+ cfg_scale: float,
48
+ width: float,
49
+ height: float,
50
+ sampler_name: str,
51
+ enable_rias_lora: bool,
52
+ rias_weight: float,
53
+ script_name: str,
54
+ ):
55
+ """
56
+ Fonction appelée par Gradio : envoie une requête txt2img à A1111,
57
+ récupère l'image, et sauvegarde les paramètres dans les métadonnées PNG
58
+ + un JSON affiché dans l'UI.
59
+ """
60
+
61
+ # Prompt final
62
+ final_prompt = build_prompt_with_lora(
63
+ prompt,
64
+ enable_rias_lora,
65
+ float(rias_weight),
66
+ )
67
+
68
+ # Seed : -1 ou vide => seed aléatoire côté A1111
69
+ try:
70
+ seed_int = int(seed)
71
+ except Exception:
72
+ seed_int = -1
73
+
74
+ payload = {
75
+ "prompt": final_prompt,
76
+ "negative_prompt": negative_prompt or "",
77
+ "seed": seed_int,
78
+ "steps": int(steps),
79
+ "cfg_scale": float(cfg_scale),
80
+ "width": int(width),
81
+ "height": int(height),
82
+ "sampler_name": sampler_name,
83
+ # Important pour SDXL : A1111 gère le reste côté modèle (refiner, etc.)
84
+ # Tu peux ajouter enable_hr, hr_scale, etc. plus tard.
85
+ }
86
+
87
+ # Appel API A1111
88
+ try:
89
+ resp = requests.post(
90
+ f"{A1111_URL}/sdapi/v1/txt2img",
91
+ json=payload,
92
+ timeout=600,
93
+ )
94
+ resp.raise_for_status()
95
+ except Exception as e:
96
+ return None, f"Erreur appel API A1111 : {e}", ""
97
+
98
+ data = resp.json()
99
+
100
+ if "images" not in data or not data["images"]:
101
+ return None, "Réponse API sans image.", json.dumps(data, ensure_ascii=False, indent=2)
102
+
103
+ # Première image
104
+ b64_image = data["images"][0]
105
+ # certaines configs renvoient "data:image/png;base64,....."
106
+ if "," in b64_image:
107
+ b64_image = b64_image.split(",", 1)[1]
108
+
109
+ img_bytes = base64.b64decode(b64_image)
110
+ img = Image.open(io.BytesIO(img_bytes))
111
+
112
+ # Récup paramètres renvoyés par A1111, si dispo
113
+ # (facultatif, mais pratique pour débogage)
114
+ a1111_params = data.get("parameters", {})
115
+
116
+ # Métadonnées à stocker
117
+ meta = {
118
+ "app_prompt": prompt,
119
+ "app_negative_prompt": negative_prompt,
120
+ "final_prompt_sent_to_a1111": final_prompt,
121
+ "seed": seed_int,
122
+ "steps": int(steps),
123
+ "cfg_scale": float(cfg_scale),
124
+ "width": int(width),
125
+ "height": int(height),
126
+ "sampler_name": sampler_name,
127
+ "enable_rias_lora": bool(enable_rias_lora),
128
+ "rias_lora_name": RIAS_LORA_NAME if enable_rias_lora else None,
129
+ "rias_lora_weight": float(rias_weight) if enable_rias_lora else None,
130
+ "a1111_parameters": a1111_params,
131
+ }
132
+
133
+ # Nom de fichier pour sauvegarde locale dans le Space (utile pour debug)
134
+ safe_name = (script_name or "sdxl_generation").strip().replace(" ", "_")
135
+ filename = f"{safe_name}.png"
136
+
137
+ # Mettre le JSON dans les métadonnées PNG
138
+ pnginfo = PngImagePlugin.PngInfo()
139
+ pnginfo.add_text("generation_params", json.dumps(meta, ensure_ascii=False))
140
+
141
+ img.save(filename, pnginfo=pnginfo)
142
+
143
+ # Texte affiché dans la zone "script" (JSON lisible)
144
+ script_txt = json.dumps(meta, ensure_ascii=False, indent=2)
145
+
146
+ return img, script_txt, filename
147
+
148
+
149
+ # ---------------- Gradio UI ----------------
150
+
151
+ with gr.Blocks(title="SDXL + LoRA (A1111 client)") as demo:
152
+ gr.Markdown(
153
+ "## SDXL (Automatic1111) + LoRA Rias \n"
154
+ "Ce Space envoie les requêtes à un serveur Automatic1111 déjà configuré en SDXL."
155
+ )
156
+
157
+ with gr.Row():
158
+ with gr.Column():
159
+ prompt = gr.Textbox(
160
+ label="Prompt",
161
+ lines=4,
162
+ placeholder="Décris la scène (SDXL + LoRA Rias si activée)...",
163
+ )
164
+ negative = gr.Textbox(
165
+ label="Negative prompt",
166
+ lines=3,
167
+ value="",
168
+ )
169
+
170
+ seed = gr.Number(
171
+ label="Seed (-1 = random)",
172
+ value=-1,
173
+ )
174
+ steps = gr.Slider(
175
+ minimum=10,
176
+ maximum=60,
177
+ value=30,
178
+ step=1,
179
+ label="Steps",
180
+ )
181
+ cfg_scale = gr.Slider(
182
+ minimum=1.0,
183
+ maximum=20.0,
184
+ value=7.0,
185
+ step=0.5,
186
+ label="CFG scale",
187
+ )
188
+
189
+ width = gr.Slider(
190
+ minimum=512,
191
+ maximum=1536,
192
+ value=1024,
193
+ step=64,
194
+ label="Width (SDXL natif 1024)",
195
+ )
196
+ height = gr.Slider(
197
+ minimum=512,
198
+ maximum=1536,
199
+ value=1024,
200
+ step=64,
201
+ label="Height (SDXL natif 1024)",
202
+ )
203
+
204
+ sampler_name = gr.Dropdown(
205
+ choices=[
206
+ "Euler a",
207
+ "Euler",
208
+ "DPM++ 2M Karras",
209
+ "DPM++ SDE Karras",
210
+ ],
211
+ value="DPM++ 2M Karras",
212
+ label="Sampler",
213
+ )
214
+
215
+ enable_rias_lora = gr.Checkbox(
216
+ label="Activer LoRA RiasGremory",
217
+ value=True,
218
+ )
219
+ rias_weight = gr.Slider(
220
+ minimum=0.1,
221
+ maximum=1.5,
222
+ value=0.9,
223
+ step=0.05,
224
+ label="Poids LoRA Rias",
225
+ )
226
+
227
+ script_name = gr.Textbox(
228
+ label="Nom base pour l'image / script",
229
+ value="rias_sdxl_generation",
230
+ )
231
+
232
+ run_btn = gr.Button("Générer", variant="primary")
233
+
234
+ with gr.Column():
235
+ out_img = gr.Image(
236
+ label="Image générée (SDXL via A1111)",
237
+ show_download=True,
238
+ )
239
+ out_script = gr.Textbox(
240
+ label="Script (JSON des paramètres + payload A1111)",
241
+ lines=20,
242
+ )
243
+ out_file = gr.File(
244
+ label="Fichier PNG (avec métadonnées) généré côté Space",
245
+ )
246
+
247
+ run_btn.click(
248
+ fn=txt2img_sdxl,
249
+ inputs=[
250
+ prompt,
251
+ negative,
252
+ seed,
253
+ steps,
254
+ cfg_scale,
255
+ width,
256
+ height,
257
+ sampler_name,
258
+ enable_rias_lora,
259
+ rias_weight,
260
+ script_name,
261
+ ],
262
+ outputs=[out_img, out_script, out_file],
263
+ )
264
+
265
+ if __name__ == "__main__":
266
+ demo.launch()