Spaces:
Running
Running
import concurrent.futures
import os
import tempfile
import threading
from io import BytesIO

import gradio as gr
import requests
import torch
from diffusers import DiffusionPipeline
from PIL import Image
# ==============================
# BASE CPU CONFIGURATION
# ==============================
# This app only runs inference, so autograd is switched off globally up front.
torch.set_grad_enabled(False)
DEVICE = "cpu"

# Default generation parameters, lowered so CPU inference finishes faster.
DEFAULT_STEPS = 15  # was 20; fewer denoising steps for speed
# Output size in pixels: was 576x1024; 512x768 keeps a similar aspect ratio
# while reducing the per-step compute load.
DEFAULT_WIDTH, DEFAULT_HEIGHT = 512, 768
def load_flux(model_id):
    """Load a diffusers pipeline for ``model_id`` ready for CPU inference.

    The weights are loaded in float32, the pipeline is moved to ``DEVICE`` and
    attention slicing is enabled.

    Args:
        model_id: Hugging Face Hub repo id understood by
            ``DiffusionPipeline.from_pretrained``.

    Returns:
        The loaded pipeline object.
    """
    pipeline = DiffusionPipeline.from_pretrained(
        model_id,
        torch_dtype=torch.float32,
    )
    pipeline.to(DEVICE)
    # Compute attention in slices to trade speed for lower peak memory.
    pipeline.enable_attention_slicing()
    return pipeline
# Cache of loaded pipelines: keyed by FLUX repo id, or "sd15" for SD1.5.
MODEL_CACHE = {}


# ==============================
# FLUX GENERATOR
# ==============================
def generate_flux(model_name, prompt, steps, guidance, width, height, seed):
    """Generate one image with a FLUX pipeline and save it as a PNG file.

    The pipeline for ``model_name`` is loaded with ``load_flux`` on first use
    and kept in ``MODEL_CACHE`` for subsequent calls.

    Args:
        model_name: Hugging Face repo id passed to ``load_flux``.
        prompt: Text prompt for the image.
        steps: Number of inference steps.
        guidance: Guidance scale.
        width: Output width in pixels.
        height: Output height in pixels.
        seed: Integer seed; a falsy value (0 or None) leaves sampling unseeded.

    Returns:
        Path to a newly created PNG file, unique to this call.
    """
    if model_name not in MODEL_CACHE:
        MODEL_CACHE[model_name] = load_flux(model_name)
    pipe = MODEL_CACHE[model_name]

    # A private generator: torch.manual_seed() would reseed the process-wide
    # default generator shared by every concurrent request.
    generator = torch.Generator(device=DEVICE).manual_seed(seed) if seed else None

    image = pipe(
        prompt=prompt,
        num_inference_steps=steps,
        guidance_scale=guidance,
        width=width,
        height=height,
        generator=generator,
    ).images[0]

    # Unique file per call: a fixed path lets simultaneous requests overwrite
    # each other's output before it is served.
    fd, out = tempfile.mkstemp(prefix="flux_", suffix=".png")
    os.close(fd)
    image.save(out)
    return out
# ==============================
# SD1.5 GENERATOR
# ==============================
# The original "runwayml/stable-diffusion-v1-5" repo was removed from the
# Hugging Face Hub; this repo is the maintained mirror of the same weights.
SD15_MODEL_ID = "stable-diffusion-v1-5/stable-diffusion-v1-5"


def load_sd15(model_id=SD15_MODEL_ID):
    """Load a Stable Diffusion 1.5 pipeline ready for CPU inference.

    The weights are loaded in float32, the pipeline is moved to ``DEVICE`` and
    attention slicing is enabled.

    Args:
        model_id: Hub repo id of the SD1.5 weights; defaults to ``SD15_MODEL_ID``.

    Returns:
        The loaded pipeline object.
    """
    pipe = DiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float32)
    pipe.to(DEVICE)
    # Compute attention in slices to lower peak memory usage.
    pipe.enable_attention_slicing()
    return pipe
# The SD1.5 pipeline is loaded once and cached under the "sd15" key.
def generate_sd(prompt, steps, guidance, width, height, seed):
    """Generate one image with Stable Diffusion 1.5 and save it as a PNG file.

    The pipeline is loaded with ``load_sd15`` on first use and cached in
    ``MODEL_CACHE["sd15"]``.

    Args:
        prompt: Text prompt for the image.
        steps: Number of inference steps.
        guidance: Guidance scale.
        width: Output width in pixels.
        height: Output height in pixels.
        seed: Integer seed; a falsy value (0 or None) leaves sampling unseeded.

    Returns:
        Path to a newly created PNG file, unique to this call.
    """
    if "sd15" not in MODEL_CACHE:
        MODEL_CACHE["sd15"] = load_sd15()
    pipe = MODEL_CACHE["sd15"]

    # A private generator avoids reseeding torch's global default generator,
    # which is shared by all concurrent requests.
    generator = torch.Generator(device=DEVICE).manual_seed(seed) if seed else None

    image = pipe(
        prompt=prompt,
        num_inference_steps=steps,
        guidance_scale=guidance,
        width=width,
        height=height,
        generator=generator,
    ).images[0]

    # Unique file per call so concurrent requests cannot clobber each other.
    fd, out = tempfile.mkstemp(prefix="sd15_", suffix=".png")
    os.close(fd)
    image.save(out)
    return out
# ==============================
# REVE CREATE (multi-image generation with threads)
# ==============================
def generate_single_reve_image(prompt, key, model, index, results_list, lock, progress_callback=None):
    """Request one image from the REVE API and append its local path to ``results_list``.

    This is best-effort. Any failure is printed and the function returns without
    appending, so one failed image does not abort a batch.

    Args:
        prompt: Text prompt sent to the API.
        key: API key, sent as a Bearer token.
        model: REVE model name sent in the request body.
        index: Zero-based position of this image in the batch (used in logs and
            in the file name).
        results_list: List that receives the saved file path on success.
        lock: Lock guarding ``results_list`` and the progress callback.
        progress_callback: Optional callable invoked with ``index + 1`` after a
            successful save.
    """
    try:
        url = "https://api.reveai.xyz/v1/images"
        headers = {"Authorization": f"Bearer {key}"}
        data = {"prompt": prompt, "model": model}
        resp = requests.post(url, json=data, headers=headers, timeout=30)
        if resp.status_code != 200:
            print(f"Error en imagen {index+1}: {resp.status_code}")
            return
        img_url = resp.json().get("image")
        if not img_url:
            # Previously a silent return; log it so missing images are traceable.
            print(f"Error en imagen {index+1}: respuesta sin URL de imagen")
            return

        img_resp = requests.get(img_url, timeout=30)
        # Fail loudly on HTTP errors instead of handing an error page to PIL.
        img_resp.raise_for_status()
        img = Image.open(BytesIO(img_resp.content))

        # mkstemp guarantees a fresh file; thread idents are reused by the pool
        # and across requests, so the old name could overwrite earlier results.
        fd, out = tempfile.mkstemp(prefix=f"reve_{index}_", suffix=".png")
        os.close(fd)
        img.save(out)

        with lock:
            results_list.append(out)
            # Report progress only after the file exists on disk.
            if progress_callback:
                progress_callback(index + 1)
    except Exception as e:
        # Best-effort per image: log and let the rest of the batch continue.
        print(f"Error generando imagen {index+1}: {e}")
def reve_generate_multiple(prompt, key, model, num_images, progress_callback=None, max_images=8, max_workers=4):
    """Generate several REVE images concurrently.

    Args:
        prompt: Text prompt sent with every request.
        key: REVE API key; a falsy key returns None immediately.
        model: REVE model name.
        num_images: Requested number of images (int or numeric, e.g. a float from
            a slider); clamped to ``max_images``.
        progress_callback: Optional callable forwarded to each worker.
        max_images: Upper bound on the number of images (default 8).
        max_workers: Upper bound on concurrent requests (default 4).

    Returns:
        List of saved file paths in request order, or None when there is no
        key, fewer than one image was requested, or every request failed.
    """
    if not key:
        return None

    # Slider values may arrive as floats; range() and max_workers need ints.
    num_images = min(int(num_images), max_images)
    if num_images < 1:
        # ThreadPoolExecutor(max_workers=0) would raise ValueError.
        return None

    lock = threading.Lock()
    # One list per request, so results can be flattened in request order
    # rather than the nondeterministic completion order.
    slots = [[] for _ in range(num_images)]

    workers = max(1, min(num_images, max_workers))
    with concurrent.futures.ThreadPoolExecutor(max_workers=workers) as executor:
        futures = [
            executor.submit(
                generate_single_reve_image,
                prompt, key, model, i, slot, lock, progress_callback,
            )
            for i, slot in enumerate(slots)
        ]
        concurrent.futures.wait(futures)

    results = [path for slot in slots for path in slot]
    return results or None
# ==============================
# FULL UI
# ==============================
def build_ui():
    """Build the Gradio app with FLUX, Stable Diffusion 1.5 and REVE CREATE tabs.

    Returns:
        The ``gr.Blocks`` instance; the caller is responsible for launching it.
    """

    def _as_int(value, default):
        # gr.Number yields None when the field is cleared; fall back to a
        # default instead of letting int(None) raise.
        return default if value is None else int(value)

    with gr.Blocks(title="BATUTO-ART MIX", theme=gr.themes.Soft()) as demo:
        gr.Markdown("""
# 🖼️ **BATUTO-ART MIX**
*Generador de imágenes con FLUX, Stable Diffusion 1.5 y REVE CREATE*
""")
        with gr.Tabs():
            # ============================
            # TAB: FLUX
            # ============================
            with gr.Tab("FLUX.2 / 1-Schnell"):
                with gr.Row():
                    with gr.Column(scale=1):
                        flux_prompt = gr.Textbox(
                            label="Prompt",
                            lines=3,
                            placeholder="Describe la imagen que quieres generar con FLUX..."
                        )
                        model_select = gr.Dropdown([
                            "black-forest-labs/FLUX.1-schnell",
                            "black-forest-labs/FLUX.1-dev",
                            "black-forest-labs/FLUX.2-dev"
                        ], value="black-forest-labs/FLUX.1-schnell", label="Modelo FLUX")
                        with gr.Row():
                            steps = gr.Slider(5, 50, value=DEFAULT_STEPS, step=1, label="Steps")
                            guidance = gr.Slider(0, 10, value=3, step=0.1, label="Guidance Scale")
                        with gr.Row():
                            width = gr.Number(value=DEFAULT_WIDTH, label="Width", precision=0)
                            height = gr.Number(value=DEFAULT_HEIGHT, label="Height", precision=0)
                            seed = gr.Number(value=0, label="Seed (0 = aleatorio)", precision=0)
                        btn_flux = gr.Button("✨ Generar Imagen", variant="primary", size="lg")
                    with gr.Column(scale=1):
                        out_flux_img = gr.Image(
                            label="Resultado FLUX",
                            height=400,
                            interactive=False
                        )
                        out_flux_file = gr.File(
                            label="Descargar imagen",
                            visible=False
                        )

                def generate_flux_wrapper(model_name, prompt, steps, guidance, width, height, seed):
                    # Errors are raised as gr.Error (shown as a UI alert). Returning a
                    # message string into the gr.File output would itself fail.
                    if not (prompt or "").strip():
                        raise gr.Error("❌ Error: Ingresa un prompt válido")
                    try:
                        file_path = generate_flux(
                            model_name, prompt, int(steps), float(guidance),
                            _as_int(width, DEFAULT_WIDTH),
                            _as_int(height, DEFAULT_HEIGHT),
                            _as_int(seed, 0),
                        )
                    except Exception as e:
                        raise gr.Error(f"❌ Error: {str(e)}") from e
                    # Show the image directly and reveal the download link.
                    return file_path, gr.update(value=file_path, visible=True)

                btn_flux.click(
                    fn=generate_flux_wrapper,
                    inputs=[model_select, flux_prompt, steps, guidance, width, height, seed],
                    outputs=[out_flux_img, out_flux_file]
                )

            # ============================
            # TAB: SD1.5
            # ============================
            with gr.Tab("Stable Diffusion 1.5"):
                with gr.Row():
                    with gr.Column(scale=1):
                        sd_prompt = gr.Textbox(
                            label="Prompt",
                            lines=3,
                            placeholder="Describe la imagen que quieres generar con SD1.5..."
                        )
                        with gr.Row():
                            sd_steps = gr.Slider(5, 50, value=DEFAULT_STEPS, step=1, label="Steps")
                            sd_guidance = gr.Slider(0, 10, value=3, step=0.1, label="Guidance Scale")
                        with gr.Row():
                            sd_width = gr.Number(value=DEFAULT_WIDTH, label="Width", precision=0)
                            sd_height = gr.Number(value=DEFAULT_HEIGHT, label="Height", precision=0)
                            sd_seed = gr.Number(value=0, label="Seed (0 = aleatorio)", precision=0)
                        btn_sd = gr.Button("✨ Generar Imagen", variant="primary", size="lg")
                    with gr.Column(scale=1):
                        out_sd_img = gr.Image(
                            label="Resultado SD1.5",
                            height=400,
                            interactive=False
                        )
                        out_sd_file = gr.File(
                            label="Descargar imagen",
                            visible=False
                        )

                def generate_sd_wrapper(prompt, steps, guidance, width, height, seed):
                    # Same error strategy as the FLUX tab: raise gr.Error, never
                    # send a message string into the File output.
                    if not (prompt or "").strip():
                        raise gr.Error("❌ Error: Ingresa un prompt válido")
                    try:
                        file_path = generate_sd(
                            prompt, int(steps), float(guidance),
                            _as_int(width, DEFAULT_WIDTH),
                            _as_int(height, DEFAULT_HEIGHT),
                            _as_int(seed, 0),
                        )
                    except Exception as e:
                        raise gr.Error(f"❌ Error: {str(e)}") from e
                    return file_path, gr.update(value=file_path, visible=True)

                btn_sd.click(
                    fn=generate_sd_wrapper,
                    inputs=[sd_prompt, sd_steps, sd_guidance, sd_width, sd_height, sd_seed],
                    outputs=[out_sd_img, out_sd_file]
                )

            # ============================
            # TAB: REVE CREATE
            # ============================
            with gr.Tab("REVE CREATE"):
                with gr.Row():
                    with gr.Column(scale=1):
                        reve_api = gr.Textbox(
                            label="API Key REVE",
                            type="password",
                            placeholder="Ingresa tu API key de REVE",
                            info="Necesitas una clave API válida de REVE"
                        )
                        reve_prompt = gr.Textbox(
                            label="Prompt",
                            lines=3,
                            placeholder="Describe la imagen que quieres generar..."
                        )
                        reve_model = gr.Dropdown([
                            "reve-1",
                            "reve-2",
                            "reve-fast"
                        ], value="reve-fast", label="Modelo REVE")
                        # Number of images to generate in one batch.
                        num_images_slider = gr.Slider(
                            minimum=1,
                            maximum=8,
                            value=4,
                            step=1,
                            label="Cantidad de imágenes a generar"
                        )
                        with gr.Row():
                            btn_reve = gr.Button("🚀 Generar Imágenes", variant="primary", size="lg")
                            btn_clear = gr.Button("🗑️ Limpiar", variant="secondary")
                        # Status line, updated as a regular event output.
                        progress_info = gr.Markdown("Esperando generación...")
                    with gr.Column(scale=2):
                        gallery = gr.Gallery(
                            label="Imágenes generadas",
                            show_label=True,
                            columns=4,
                            rows=2,
                            height="auto",
                            object_fit="contain"
                        )
                        result_info = gr.Markdown("")
                        with gr.Row():
                            btn_download_all = gr.Button(
                                "📥 Descargar todas las imágenes",
                                size="lg",
                                variant="secondary"
                            )
                        download_output = gr.File(
                            label="Archivos descargables",
                            file_count="multiple",
                            visible=False
                        )
                # Paths of the most recent batch, used by the download button.
                last_files = gr.State([])

                def generate_and_update_gallery(prompt, key, model, num_images, progress=gr.Progress()):
                    # Every path returns exactly four values, one per output. A
                    # component's .update() does not change the UI from inside a
                    # handler, so the status text is returned as an output.
                    if not key:
                        return [], [], "❌ Error: Ingresa una API key válida", ""
                    if not (prompt or "").strip():
                        return [], [], "❌ Error: Ingresa un prompt válido", ""
                    progress(0, desc="⏳ Iniciando generación...")
                    files = reve_generate_multiple(prompt, key, model, num_images)
                    if files:
                        # Gallery accepts a plain list of file paths.
                        info_text = f"✅ Generadas {len(files)} de {int(num_images)} imágenes"
                        return files, files, info_text, "✅ Generación completada"
                    return (
                        [],
                        [],
                        "❌ Error: No se pudieron generar imágenes. Verifica tu API key y conexión.",
                        "❌ Error en la generación",
                    )

                btn_reve.click(
                    fn=generate_and_update_gallery,
                    inputs=[reve_prompt, reve_api, reve_model, num_images_slider],
                    outputs=[gallery, last_files, result_info, progress_info]
                )

                def prepare_download(files):
                    # Show the multi-file download only when there is something to offer.
                    if files:
                        return gr.update(value=files, visible=True)
                    return gr.update(value=[], visible=False)

                btn_download_all.click(
                    fn=prepare_download,
                    inputs=[last_files],
                    outputs=[download_output]
                )

                def clear_gallery():
                    return [], [], "✅ Galería limpiada", gr.update(value="Listo para nueva generación")

                btn_clear.click(
                    fn=clear_gallery,
                    outputs=[gallery, last_files, result_info, progress_info]
                ).then(
                    fn=lambda: gr.update(visible=False),
                    outputs=[download_output]
                )
    return demo
# ==============================
# MAIN ENTRY POINT
# ==============================
if __name__ == "__main__":
    # Command-line options; the defaults suit Hugging Face Spaces.
    import argparse

    cli = argparse.ArgumentParser()
    cli.add_argument("--share", action="store_true", help="Crear enlace público")
    cli.add_argument("--server-name", type=str, default="0.0.0.0", help="Dirección del servidor")
    cli.add_argument("--server-port", type=int, default=7860, help="Puerto del servidor")
    options = cli.parse_args()

    # Startup banner summarising the active defaults.
    rule = "=" * 50
    for banner_line in (
        rule,
        "🚀 Iniciando BATUTO-ART MIX",
        rule,
        f"📱 Dispositivo: {DEVICE}",
        f"⚡ Steps por defecto: {DEFAULT_STEPS}",
        f"📐 Resolución por defecto: {DEFAULT_WIDTH}x{DEFAULT_HEIGHT}",
        rule,
    ):
        print(banner_line)

    demo = build_ui()
    demo.launch(
        share=options.share,
        server_name=options.server_name,
        server_port=options.server_port,
        show_error=True,
    )