BATUTO-ART committed on
Commit
7ce094c
·
verified ·
1 Parent(s): d30b575

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +411 -0
app.py ADDED
@@ -0,0 +1,411 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# app.py - BATUTO-ART MIX (optimized for HF Spaces - CPU)
import os
import io
import time
import random
import threading
import requests
from pathlib import Path
from typing import Optional
from PIL import Image
import gradio as gr

# Diffusers imports (may be unavailable if torch/diffusers are not installed on CPU)
try:
    import torch
    from diffusers import FluxPipeline, StableDiffusionPipeline
except Exception as e:
    # On HF Spaces, install diffusers + accelerate + transformers and a matching torch build.
    # The app still imports; pipeline loaders raise RuntimeError later if these are missing.
    print("Warning: diffusers/torch no cargados:", e)

# -----------------------
# Config and paths
# -----------------------
# Generated images are written under /tmp (ephemeral storage on HF Spaces).
OUT_DIR = Path("/tmp/batuto_art_mix")
OUT_DIR.mkdir(parents=True, exist_ok=True)

# Default 9:16 portrait resolution (balance of quality vs CPU memory)
DEFAULT_WIDTH = 576
DEFAULT_HEIGHT = 1024

# Hugging Face token (optional; forwarded to from_pretrained for gated models)
HF_TOKEN = os.getenv("HF_TOKEN", "")

# REVE API endpoint (adjust here if you use a different provider/version)
REVE_API_URL = "https://api.reve.com/v1/image/create"

# -----------------------
# Global pipeline cache:
#   "flux" maps model_id -> loaded FluxPipeline, "sd" holds a single SD pipeline.
# Access is serialized with _PIPELINE_LOCK so concurrent Gradio requests
# don't load the same heavy model twice.
# -----------------------
_PIPELINE_CACHE = {
    "flux": {},
    "sd": None
}
_PIPELINE_LOCK = threading.Lock()
45
+
46
+ # -----------------------
47
+ # Utilities
48
+ # -----------------------
49
def save_image_to_file(img: Image.Image, prefix: str = "batuto", fmt: str = "png") -> str:
    """Persist *img* into OUT_DIR and return the saved file path as a string.

    The filename is ``{prefix}_{ms_timestamp}.{fmt}``. Previously the ``fmt``
    argument was accepted but ignored (the extension and PIL format were
    hard-coded to PNG); it is now honoured, with ``png`` remaining the default
    so existing callers are unaffected.

    Args:
        img: PIL image to save.
        prefix: filename prefix ("flux", "sd", "reve", ...).
        fmt: output format/extension, e.g. "png" or "jpg".
    """
    # Millisecond timestamp keeps names unique enough within one process.
    ts = int(time.time() * 1000)
    ext = fmt.lower().lstrip(".")
    fn = OUT_DIR / f"{prefix}_{ts}.{ext}"
    # PIL calls the JPEG encoder "JPEG", not "JPG".
    pil_format = {"jpg": "JPEG"}.get(ext, ext.upper())
    img.save(fn, format=pil_format)
    return str(fn)
54
+
55
def pil_image_to_bytes(img: Image.Image):
    """Serialize a PIL image into an in-memory PNG buffer, rewound to offset 0."""
    buffer = io.BytesIO()
    img.save(buffer, format="PNG")
    buffer.seek(0)
    return buffer
60
+
61
def random_seed():
    """Return a random non-negative seed that fits in a signed 32-bit integer."""
    return random.randrange(2 ** 31)
63
+
64
+ # -----------------------
65
+ # Loaders: FLUX
66
+ # -----------------------
67
def load_flux_pipeline(model_id: str):
    """
    Load (and cache) a FluxPipeline configured for CPU inference.

    Uses torch.float32 for CPU compatibility. Loaded pipelines are memoised in
    _PIPELINE_CACHE["flux"] keyed by model_id, and the whole load is serialised
    with _PIPELINE_LOCK so concurrent requests don't load the same model twice.

    Raises:
        RuntimeError: if diffusers' FluxPipeline was not imported successfully.
    """
    if "FluxPipeline" not in globals():
        raise RuntimeError("diffusers FluxPipeline no está disponible en este entorno.")
    with _PIPELINE_LOCK:
        # Cache hit: return the already-loaded pipeline.
        if model_id in _PIPELINE_CACHE["flux"]:
            return _PIPELINE_CACHE["flux"][model_id]
        # Try to load; float32 is required on CPU (half precision is GPU-oriented).
        dtype = torch.float32
        kwargs = {}
        if HF_TOKEN:
            # NOTE(review): `use_auth_token` is deprecated in newer huggingface_hub;
            # `token=` may be required — confirm against the installed diffusers version.
            kwargs["use_auth_token"] = HF_TOKEN
        pipe = FluxPipeline.from_pretrained(model_id, torch_dtype=dtype, **kwargs)
        # Force CPU execution.
        pipe.to("cpu")
        # CPU memory optimisations, best-effort: not every diffusers version
        # supports these methods, so failures are deliberately ignored.
        try:
            pipe.enable_sequential_cpu_offload()
        except Exception:
            pass
        try:
            pipe.enable_attention_slicing()
        except Exception:
            pass
        _PIPELINE_CACHE["flux"][model_id] = pipe
        return pipe
96
+
97
+ # -----------------------
98
+ # Loaders: Stable Diffusion (advertencia: pesado en CPU)
99
+ # -----------------------
100
def load_sd_pipeline(model_id: str = "runwayml/stable-diffusion-v1-5"):
    """
    Load (and cache) a StableDiffusionPipeline for CPU inference.

    Only one SD pipeline is kept (the first model_id requested wins); it is
    stored in _PIPELINE_CACHE["sd"] under _PIPELINE_LOCK. The safety checker
    is disabled (safety_checker=None) to save CPU memory.

    Raises:
        RuntimeError: if diffusers' StableDiffusionPipeline was not imported.
    """
    if "StableDiffusionPipeline" not in globals():
        raise RuntimeError("StableDiffusionPipeline no está disponible en este entorno.")
    with _PIPELINE_LOCK:
        if _PIPELINE_CACHE["sd"] is not None:
            return _PIPELINE_CACHE["sd"]
        # float32 for CPU compatibility.
        dtype = torch.float32
        kwargs = dict()
        if HF_TOKEN:
            # NOTE(review): `use_auth_token` is deprecated in newer huggingface_hub;
            # `token=` may be required — confirm against the installed diffusers version.
            kwargs["use_auth_token"] = HF_TOKEN
        pipe = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=dtype, safety_checker=None, **kwargs)
        pipe.to("cpu")
        # Best-effort memory optimisation; ignored if unsupported.
        try:
            pipe.enable_attention_slicing()
        except Exception:
            pass
        _PIPELINE_CACHE["sd"] = pipe
        return pipe
118
+
119
+ # -----------------------
120
+ # Generación: FLUX
121
+ # -----------------------
122
def generate_flux(model_id: str, prompt: str, negative_prompt: str,
                  steps: int, guidance: float, width: int, height: int,
                  seed: int):
    """Generate one image with a FLUX pipeline on CPU.

    Returns:
        (image, saved_path, seed_as_str) on success, otherwise
        (None, None, error_message).
    """
    if not prompt or prompt.strip() == "":
        return None, None, "Prompt vacío"

    try:
        pipe = load_flux_pipeline(model_id)
    except Exception as e:
        return None, None, f"Error cargando FLUX: {e}"

    # seed == -1 (or None) means "pick a random seed"; the seed used is
    # returned so runs can be reproduced.
    if seed is None or int(seed) == -1:
        seed = random_seed()
    gen = torch.Generator(device="cpu").manual_seed(int(seed))

    call_kwargs = dict(
        num_inference_steps=int(steps),
        guidance_scale=float(guidance),
        height=int(height),
        width=int(width),
        generator=gen,
    )
    # Fix: FLUX pipelines in several diffusers releases do not accept a
    # `negative_prompt` kwarg, so unconditionally passing it (even as None)
    # raised TypeError and generation always failed. Forward it only when the
    # user actually provided one, and drop it on TypeError as a fallback.
    if negative_prompt:
        call_kwargs["negative_prompt"] = negative_prompt

    try:
        try:
            out = pipe(prompt, **call_kwargs)
        except TypeError:
            if "negative_prompt" in call_kwargs:
                call_kwargs.pop("negative_prompt")
                out = pipe(prompt, **call_kwargs)
            else:
                raise
        img = out.images[0]
        path = save_image_to_file(img, prefix="flux")
        return img, path, str(seed)
    except Exception as e:
        return None, None, f"Error generación FLUX: {e}"
152
+
153
+ # -----------------------
154
+ # Generación: Stable Diffusion (ligera, en CPU puede ser muy lenta)
155
+ # -----------------------
156
def generate_sd(prompt: str, negative_prompt: str, steps: int, guidance: float,
                width: int, height: int, seed: int):
    """Run one Stable Diffusion generation on CPU (can be very slow).

    Returns:
        (image, saved_path, seed_as_str) on success, otherwise
        (None, None, error_message).
    """
    if not prompt or prompt.strip() == "":
        return None, None, "Prompt vacío"
    try:
        pipe = load_sd_pipeline()
    except Exception as e:
        return None, None, f"Error cargando SD: {e}"

    # -1 / None requests a fresh random seed; the seed used is reported back.
    if seed is None or int(seed) == -1:
        seed = random_seed()
    generator = torch.Generator(device="cpu").manual_seed(int(seed))

    try:
        result = pipe(
            prompt=prompt,
            negative_prompt=negative_prompt or None,
            num_inference_steps=int(steps),
            guidance_scale=float(guidance),
            height=int(height),
            width=int(width),
            generator=generator,
        )
        image = result.images[0]
        saved_path = save_image_to_file(image, prefix="sd")
        return image, saved_path, str(seed)
    except Exception as e:
        return None, None, f"Error generación SD: {e}"
184
+
185
+ # -----------------------
186
+ # Generación: REVE Create (API)
187
+ # -----------------------
188
def generate_reve(prompt: str, api_key: str, ratio: str, version: str, num_images: int,
                  width: int, height: int):
    """
    Call the REVE image-creation API, once per requested image.

    Assumes the API accepts a JSON payload with
    {"prompt", "aspect_ratio", "version", "num_images", "width", "height"}
    and returns the image base64-encoded under the "image" key.

    Returns:
        (results, status_message) where results is a list of
        (PIL.Image, saved_path) tuples. If every request failed, results is
        empty and the message joins the first three errors.
    """
    import base64  # fix: was re-imported on every loop iteration

    if not api_key or api_key.strip() == "":
        return [], "Falta API Key"

    # NOTE(review): the payload carries "num_images" while the loop below ALSO
    # issues one request per image — if the API honours num_images this
    # over-generates; confirm against the REVE API contract.
    payload = {
        "prompt": prompt,
        "aspect_ratio": ratio,
        "version": version or "latest",
        "num_images": num_images,
        "width": width,
        "height": height
    }
    headers = {
        "Authorization": f"Bearer {api_key}",
        "Accept": "application/json",
        "Content-Type": "application/json"
    }
    results = []
    errors = []
    for _ in range(int(num_images)):
        try:
            resp = requests.post(REVE_API_URL, headers=headers, json=payload, timeout=60)
            if resp.status_code == 200:
                data = resp.json()
                if "image" in data:
                    img_bytes = base64.b64decode(data["image"])
                    # Normalise to RGB (drop any alpha channel) before saving.
                    img = Image.open(io.BytesIO(img_bytes)).convert("RGBA").convert("RGB")
                    path = save_image_to_file(img, prefix="reve")
                    results.append((img, path))
                else:
                    errors.append(f"Respuesta sin 'image': {data.keys()}")
            else:
                errors.append(f"{resp.status_code} {resp.text[:200]}")
        except Exception as e:
            # Collect per-request failures; partial success is still reported.
            errors.append(str(e))
    if errors and not results:
        return [], " | ".join(errors[:3])
    return results, f"{len(results)} generadas, errores: {len(errors)}"
233
+
234
+ # -----------------------
235
+ # UI - Gradio
236
+ # -----------------------
237
def build_ui():
    """Build and return the Gradio Blocks app with three tabs: FLUX, SD, REVE."""
    with gr.Blocks(title="BATUTO-ART MIX", theme=gr.themes.Soft()) as demo:
        gr.Markdown("# 🎨 BATUTO-ART MIX — Generador (Optimizado para CPU, 9:16)")

        # Session-level image history.
        # NOTE(review): the FLUX/SD handlers route their new history to the tab
        # galleries and then run `.then(lambda h: h, ...)` which reads and writes
        # historial_state unchanged — the shared state is never actually updated
        # by those tabs; only the REVE tab writes to it. Confirm intent.
        historial_state = gr.State([])

        with gr.Tabs():
            # ----------------- TAB FLUX (multi-model selector) -----------------
            with gr.TabItem("FLUX Models"):
                with gr.Row():
                    with gr.Column(scale=2):
                        model_flux = gr.Dropdown(
                            choices=[
                                "black-forest-labs/FLUX.1-dev",
                                "black-forest-labs/FLUX.1-schnell",
                                "black-forest-labs/FLUX.2-dev"
                            ],
                            label="Selecciona Modelo FLUX",
                            value="black-forest-labs/FLUX.1-schnell"
                        )
                        prompt_flux = gr.Textbox(label="Prompt", lines=4, placeholder="Describe la escena...")
                        negative_flux = gr.Textbox(label="Negative Prompt (opcional)", lines=2)
                        with gr.Row():
                            steps_flux = gr.Slider(4, 50, value=12, step=1, label="Steps")
                            guidance_flux = gr.Slider(1.0, 20.0, value=4.0, step=0.5, label="Guidance")
                        with gr.Row():
                            width_flux = gr.Number(value=DEFAULT_WIDTH, label="Width (px)")
                            height_flux = gr.Number(value=DEFAULT_HEIGHT, label="Height (px)")
                            seed_flux = gr.Number(value=-1, label="Seed (-1 random)")
                        with gr.Row():
                            gen_flux_btn = gr.Button("✨ Generar (FLUX)", variant="primary")
                            clear_flux_btn = gr.Button("🧹 Limpiar Prompt", variant="secondary")
                            download_flux_btn = gr.Button("⬇️ Guardar Última", variant="secondary")
                        flux_status = gr.Markdown("")

                    with gr.Column(scale=3):
                        out_flux_img = gr.Image(label="Imagen FLUX", type="pil", show_label=True)
                        out_flux_seed = gr.Textbox(label="Semilla usada", interactive=False)
                        flux_last_file = gr.Textbox(visible=False)  # internal store: last saved file path
                        # Gallery / session history for this tab
                        flux_gallery = gr.Gallery(label="Historial FLUX (sesión)", columns=4, object_fit="cover", height=300)

                # FLUX wiring
                def _clear_flux():
                    # Reset both prompt boxes.
                    return "", ""
                clear_flux_btn.click(fn=_clear_flux, inputs=None, outputs=[prompt_flux, negative_flux])

                def _download_last(flux_last):
                    # NOTE(review): gr.File.update was removed in Gradio 4;
                    # returning the path (or gr.update(value=...)) may be required
                    # depending on the installed Gradio version — confirm.
                    if flux_last:
                        return gr.File.update(value=flux_last)
                    return None

                download_file_flux = gr.File(label="Descargar imagen", visible=True)
                download_flux_btn.click(fn=_download_last, inputs=[flux_last_file], outputs=[download_file_flux])

                def _gen_flux(model_id, prompt, negative, steps, guidance, width, height, seed, hist):
                    # Delegate to generate_flux; on failure surface the error in the status box.
                    img, path, info = generate_flux(model_id, prompt, negative, steps, guidance, width, height, seed)
                    if img is None:
                        return None, "", "", hist, f"❌ {info}"
                    # Update history (appended image goes to the gallery output).
                    new_hist = (hist or []) + [img]
                    return img, str(info), path, new_hist, f"✅ Generada | seed: {info}"

                gen_flux_btn.click(
                    fn=_gen_flux,
                    inputs=[model_flux, prompt_flux, negative_flux, steps_flux, guidance_flux, width_flux, height_flux, seed_flux, historial_state],
                    outputs=[out_flux_img, out_flux_seed, flux_last_file, flux_gallery, flux_status]
                ).then(lambda h: h, inputs=[historial_state], outputs=[historial_state])

            # ----------------- TAB Stable Diffusion -----------------
            with gr.TabItem("Stable Diffusion (opcional)"):
                gr.Markdown("**Aviso:** Stable Diffusion en CPU puede ser muy lento. Úsalo si no tienes prisa.")
                with gr.Row():
                    with gr.Column(scale=2):
                        prompt_sd = gr.Textbox(label="Prompt", lines=4)
                        negative_sd = gr.Textbox(label="Negative Prompt (opcional)", lines=2)
                        with gr.Row():
                            steps_sd = gr.Slider(10, 50, value=20, step=1, label="Steps")
                            # NOTE(review): this slider has no label= (unlike every other control).
                            guidance_sd = gr.Slider(1.0, 20.0, value=7.5, step=0.5)
                        with gr.Row():
                            width_sd = gr.Number(value=DEFAULT_WIDTH, label="Width")
                            height_sd = gr.Number(value=DEFAULT_HEIGHT, label="Height")
                            seed_sd = gr.Number(value=-1, label="Seed (-1 random)")
                        with gr.Row():
                            gen_sd_btn = gr.Button("✨ Generar (SD)", variant="primary")
                            clear_sd_btn = gr.Button("🧹 Limpiar Prompt", variant="secondary")
                            download_sd_btn = gr.Button("⬇️ Guardar Última", variant="secondary")
                        sd_status = gr.Markdown()
                    with gr.Column(scale=3):
                        out_sd_img = gr.Image(label="Imagen SD", type="pil")
                        out_sd_seed = gr.Textbox(label="Semilla usada", interactive=False)
                        sd_last_file = gr.Textbox(visible=False)  # internal store: last saved file path
                        sd_gallery = gr.Gallery(label="Historial SD (sesión)", columns=4, object_fit="cover", height=300)

                clear_sd_btn.click(fn=lambda: ("", ""), inputs=None, outputs=[prompt_sd, negative_sd])

                download_file_sd = gr.File(label="Descargar imagen SD", visible=True)
                # NOTE(review): gr.File.update removed in Gradio 4 — see FLUX tab note.
                download_sd_btn.click(fn=lambda p: gr.File.update(value=p) if p else None, inputs=[sd_last_file], outputs=[download_file_sd])

                def _gen_sd(prompt, negative, steps, guidance, width, height, seed, hist):
                    # Delegate to generate_sd; mirror the FLUX handler's output shape.
                    img, path, info = generate_sd(prompt, negative, steps, guidance, width, height, seed)
                    if img is None:
                        return None, "", "", hist, f"❌ {info}"
                    new_hist = (hist or []) + [img]
                    return img, str(info), path, new_hist, f"✅ Generada | seed: {info}"

                gen_sd_btn.click(
                    fn=_gen_sd,
                    inputs=[prompt_sd, negative_sd, steps_sd, guidance_sd, width_sd, height_sd, seed_sd, historial_state],
                    outputs=[out_sd_img, out_sd_seed, sd_last_file, sd_gallery, sd_status]
                ).then(lambda h: h, inputs=[historial_state], outputs=[historial_state])

            # ----------------- TAB REVE CREATE -----------------
            with gr.TabItem("REVE Create (API)"):
                gr.Markdown("Aquí puedes pegar/actualizar tu **API Key** para REVE y generar desde el front-end.")
                with gr.Row():
                    with gr.Column(scale=2):
                        reve_api_input = gr.Textbox(label="🔑 REVE API Key (se puede pegar o dejar vacío para usar env var)", type="password")
                        reve_prompt = gr.Textbox(label="Prompt REVE", lines=4)
                        with gr.Row():
                            reve_ratio = gr.Dropdown(choices=["9:16", "16:9", "1:1"], value="9:16", label="Aspect Ratio")
                            reve_num = gr.Slider(1, 4, value=1, step=1, label="Cantidad")
                        reve_version = gr.Textbox(label="Version (dejar 'latest' o especificar)", value="latest")
                        with gr.Row():
                            reve_width = gr.Number(value=DEFAULT_WIDTH, label="Width")
                            reve_height = gr.Number(value=DEFAULT_HEIGHT, label="Height")
                        reve_gen_btn = gr.Button("🚀 Generar REVE")
                        reve_clear_btn = gr.Button("🧹 Limpiar Prompt REVE", variant="secondary")
                        reve_status = gr.Markdown()
                    with gr.Column(scale=3):
                        reve_gallery = gr.Gallery(label="Resultados REVE", columns=3, object_fit="cover", height=300)
                        reve_last_file = gr.Textbox(visible=False)  # internal store: last saved file path
                        reve_download = gr.File(label="Descargar REVE", visible=True)

                reve_clear_btn.click(fn=lambda: "", inputs=None, outputs=[reve_prompt])

                def _gen_reve(prompt, api_key_ui, ratio, version, num, width, height, hist):
                    # Prefer the key pasted in the UI; fall back to the REVE_API_KEY env var.
                    api_key = api_key_ui.strip() if api_key_ui and api_key_ui.strip() else os.getenv("REVE_API_KEY", "")
                    res, msg = generate_reve(prompt, api_key, ratio, version, int(num), int(width), int(height))
                    if isinstance(res, list) and res:
                        imgs = [r[0] for r in res]
                        # Paths may be multiple -> store the last one for download.
                        last_path = res[-1][1]
                        new_hist = (hist or []) + imgs
                        return imgs, last_path, new_hist, f"✅ {msg}"
                    else:
                        return [], "", hist, f"❌ {msg}"

                reve_gen_btn.click(
                    fn=_gen_reve,
                    inputs=[reve_prompt, reve_api_input, reve_ratio, reve_version, reve_num, reve_width, reve_height, historial_state],
                    outputs=[reve_gallery, reve_last_file, historial_state, reve_status]
                )

                # NOTE(review): gr.File components may not expose a .click event in
                # current Gradio, and gr.File.update was removed in v4 — confirm.
                reve_download.click(fn=lambda p: gr.File.update(value=p) if p else None, inputs=[reve_last_file], outputs=[reve_download])

        # Footer / deployment recommendations
        gr.Markdown("---")
        gr.Markdown("""
        **Notas de despliegue:**
        - Recomendado para **HF Spaces (CPU)**: usa `black-forest-labs/FLUX.1-schnell` por default.
        - Stable Diffusion en CPU será **muy lento** y puede hacer timeouts en Spaces; úsalo con cuidado.
        - REVE Create requiere tu API Key — puedes pegarla en la caja o ponerla en `REVE_API_KEY` en los Secrets del Space.
        - Tamaño 9:16 por defecto: 576×1024 (balance calidad / memoria).
        """)

    return demo
405
+
406
+ # -----------------------
407
+ # Run
408
+ # -----------------------
409
if __name__ == "__main__":
    # Bind to all interfaces so HF Spaces can route traffic; the platform
    # injects PORT (defaults to 7860 for local runs).
    demo = build_ui()
    demo.launch(server_name="0.0.0.0", server_port=int(os.environ.get("PORT", 7860)))