# fyla-image-1.0 / app.py
import os, random, gc, shutil, pathlib
import numpy as np
import gradio as gr
import torch
from PIL import Image
from huggingface_hub import login
from diffusers import (
    StableDiffusionXLPipeline,
    StableDiffusionXLImg2ImgPipeline,  # also used for the SDXL refiner checkpoint
    ControlNetModel,
    StableDiffusionXLControlNetPipeline,
    DPMSolverMultistepScheduler,
    DiffusionPipeline,
)
from diffusers.utils import load_image
import diffusers as _diff
# =========================
# 0) Optional authentication (for gated models)
# =========================
HF_TOKEN = os.getenv("HF_TOKEN")
if HF_TOKEN:
    try:
        login(HF_TOKEN)
        print("[OK] HF login")
    except Exception as e:
        print("[WARN] HF login:", e)
# =========================
# 1) Cache/storage hygiene
# =========================
os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"
CACHE_DIR = "/home/user/.cache/huggingface"
try:
    if os.path.exists(CACHE_DIR):
        shutil.rmtree(CACHE_DIR)
    pathlib.Path(CACHE_DIR).mkdir(parents=True, exist_ok=True)
    print("[OK] Cache cleared:", CACHE_DIR)
except Exception as e:
    print("[WARN] cache:", e)
print("diffusers:", _diff.__version__)
print("torch:", torch.__version__)
# =========================
# 2) Device/presets
# =========================
device = "cuda" if torch.cuda.is_available() else "cpu"
dtype = torch.float16 if device == "cuda" else torch.float32
if device == "cpu":
    DEFAULT_W, DEFAULT_H, DEFAULT_STEPS, DEFAULT_CFG = 768, 768, 22, 4.5
else:
    DEFAULT_W, DEFAULT_H, DEFAULT_STEPS, DEFAULT_CFG = 1024, 1024, 28, 5.0
MAX_SEED = np.iinfo(np.int32).max
MAX_IMAGE = 1024
# =========================
# 3) Model IDs
# =========================
# SDXL 2-stage
SDXL_BASE = "stabilityai/stable-diffusion-xl-base-1.0"
SDXL_REFINER = "stabilityai/stable-diffusion-xl-refiner-1.0"
USE_REFINER = True
# IP-Adapter SDXL
USE_IP_ADAPTER = True
IP_REPO = "h94/IP-Adapter"
IP_SUBFOLDER = "sdxl_models"  # the SDXL weights live under sdxl_models/ in the h94/IP-Adapter repo
IP_WEIGHT = "ip-adapter_sdxl.bin"  # swap if you use another variant (e.g. the SDXL-specific ViT-H weights)
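# Sketch (assumption, not wired into this app): the ViT-H "plus" variants need an explicit
# image encoder, roughly:
#   from transformers import CLIPVisionModelWithProjection
#   image_encoder = CLIPVisionModelWithProjection.from_pretrained(
#       "h94/IP-Adapter", subfolder="models/image_encoder", torch_dtype=dtype
#   )
#   # pass image_encoder=image_encoder to from_pretrained, then:
#   # pipe.load_ip_adapter("h94/IP-Adapter", subfolder="sdxl_models",
#   #                      weight_name="ip-adapter-plus-face_sdxl_vit-h.safetensors")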
# ControlNet (optional)
USE_CONTROLNET = False
CONTROLNET_OPENPOSE = "thibaud/controlnet-openpose-sdxl-1.0"
# Hunyuan (Plus mode)
HUNYUAN_ID = "tencent/HunyuanImage-3.0"
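# Note: HunyuanImage-3.0 is a very large custom-code repo; loading it via DiffusionPipeline
# below assumes the Space has the disk/VRAM for it and may also require trust_remote_code=True.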
# =========================
# 4) Utils
# =========================
def to_img(x):
    if x is None:
        return None
    if isinstance(x, Image.Image):
        return x.convert("RGB")
    return load_image(x)

def free_mem():
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
# =========================
# 5) On-demand loaders
# =========================
_base_pipe = None
_refiner_pipe = None
_control_pipe = None
_img2img_pipe = None
_hunyuan_pipe = None
def get_base():
    global _base_pipe
    if _base_pipe is None:
        pipe = StableDiffusionXLPipeline.from_pretrained(
            SDXL_BASE, torch_dtype=dtype, use_safetensors=True, cache_dir=CACHE_DIR
        )
        pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)
        pipe.enable_attention_slicing()
        pipe.vae.enable_tiling()
        if device == "cuda":
            try: pipe.enable_model_cpu_offload()
            except Exception: pipe = pipe.to(device)
            try: pipe.enable_xformers_memory_efficient_attention()
            except Exception as e: print("[WARN] xformers:", e)
        else:
            pipe = pipe.to(device)
        if USE_IP_ADAPTER:
            try:
                pipe.load_ip_adapter(IP_REPO, subfolder=IP_SUBFOLDER, weight_name=IP_WEIGHT)
                pipe.set_ip_adapter_scale(0.8)
                print("[OK] IP-Adapter (BASE)")
            except Exception as e:
                print("[WARN] IP-Adapter:", e)
        _base_pipe = pipe
    return _base_pipe
def get_refiner():
    if not USE_REFINER:
        return None
    global _refiner_pipe
    if _refiner_pipe is None:
        # In diffusers the SDXL refiner checkpoint is loaded as an img2img-style pipeline.
        pipe = StableDiffusionXLImg2ImgPipeline.from_pretrained(
            SDXL_REFINER, torch_dtype=dtype, use_safetensors=True, cache_dir=CACHE_DIR
        )
        pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)
        pipe.enable_attention_slicing()
        pipe.vae.enable_tiling()
        if device == "cuda":
            try: pipe.enable_model_cpu_offload()
            except Exception: pipe = pipe.to(device)
            try: pipe.enable_xformers_memory_efficient_attention()
            except Exception as e: print("[WARN] xformers (refiner):", e)
        else:
            pipe = pipe.to(device)
        _refiner_pipe = pipe
    return _refiner_pipe
def get_img2img():
    global _img2img_pipe
    if _img2img_pipe is None:
        base = get_base()
        pipe = StableDiffusionXLImg2ImgPipeline(
            vae=base.vae, text_encoder=base.text_encoder, tokenizer=base.tokenizer,
            unet=base.unet, scheduler=base.scheduler,
            text_encoder_2=base.text_encoder_2, tokenizer_2=base.tokenizer_2,
            feature_extractor=base.feature_extractor, image_encoder=base.image_encoder,
        )
        if USE_IP_ADAPTER:
            try:
                pipe.load_ip_adapter(IP_REPO, subfolder=IP_SUBFOLDER, weight_name=IP_WEIGHT)
                pipe.set_ip_adapter_scale(0.8)
                print("[OK] IP-Adapter (IMG2IMG)")
            except Exception as e:
                print("[WARN] IP-Adapter (img2img):", e)
        pipe.enable_attention_slicing()
        pipe.vae.enable_tiling()
        if device == "cuda":
            try: pipe.enable_model_cpu_offload()
            except Exception: pipe = pipe.to(device)
            try: pipe.enable_xformers_memory_efficient_attention()
            except Exception as e: print("[WARN] xformers (img2img):", e)
        else:
            pipe = pipe.to(device)
        _img2img_pipe = pipe
    return _img2img_pipe
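# Alternative (assumption, newer diffusers releases): AutoPipelineForImage2Image.from_pipe(base)
# builds the same shared-component img2img pipeline in a single call.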
def get_control():
    if not USE_CONTROLNET:
        return None
    global _control_pipe
    if _control_pipe is None:
        base = get_base()
        cn = ControlNetModel.from_pretrained(
            CONTROLNET_OPENPOSE, torch_dtype=dtype, use_safetensors=True, cache_dir=CACHE_DIR
        )
        pipe = StableDiffusionXLControlNetPipeline(
            vae=base.vae, text_encoder=base.text_encoder, tokenizer=base.tokenizer,
            unet=base.unet, scheduler=base.scheduler,
            text_encoder_2=base.text_encoder_2, tokenizer_2=base.tokenizer_2,
            controlnet=cn, feature_extractor=base.feature_extractor,
            image_encoder=base.image_encoder,
        )
        pipe.enable_attention_slicing()
        pipe.vae.enable_tiling()
        if device == "cuda":
            try: pipe.enable_model_cpu_offload()
            except Exception: pipe = pipe.to(device)
            try: pipe.enable_xformers_memory_efficient_attention()
            except Exception as e: print("[WARN] xformers (control):", e)
        else:
            pipe = pipe.to(device)
        _control_pipe = pipe
    return _control_pipe
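# Sketch (assumption, not part of the UI flow): the pose input is expected to already be an
# OpenPose skeleton image; one way to derive it from a photo is the controlnet_aux package:
#   from controlnet_aux import OpenposeDetector
#   openpose = OpenposeDetector.from_pretrained("lllyasviel/Annotators")
#   pose_image = openpose(load_image("person.jpg"))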
def get_hunyuan():
    global _hunyuan_pipe
    if _hunyuan_pipe is None:
        pipe = DiffusionPipeline.from_pretrained(
            HUNYUAN_ID, torch_dtype=dtype, use_safetensors=True, cache_dir=CACHE_DIR
        )
        if device == "cuda":
            try: pipe.enable_model_cpu_offload()
            except Exception: pipe = pipe.to(device)
            try: pipe.enable_xformers_memory_efficient_attention()
            except Exception as e: print("[WARN] xformers (hunyuan):", e)
        else:
            pipe = pipe.to(device)
        _hunyuan_pipe = pipe
    return _hunyuan_pipe
# =========================
# 6) SDXL 2-stage inference
# =========================
def infer_sdxl(
    prompt, negative, seed, rnd, width, height, cfg, steps,
    ref_face, ip_scale, use_img2img, img2img_strength, pose_img,
    use_refiner, refiner_fraction
):
    if rnd: seed = random.randint(0, MAX_SEED)
    gen = torch.Generator(device=device).manual_seed(int(seed))
    base = get_base()
    img2img = get_img2img() if use_img2img else None
    control = get_control() if (pose_img is not None) else None
    refiner = get_refiner() if use_refiner else None
    if USE_IP_ADAPTER:
        try:
            base.set_ip_adapter_scale(float(ip_scale))
            if img2img: img2img.set_ip_adapter_scale(float(ip_scale))
        except Exception:
            pass
    ref = to_img(ref_face)
    pose = to_img(pose_img)
    kwargs = dict(
        prompt=prompt, negative_prompt=negative or None,
        guidance_scale=float(cfg),
        num_inference_steps=int(steps),
        generator=gen,
    )
    if USE_IP_ADAPTER and ref is not None:
        kwargs["ip_adapter_image"] = ref
    if control and pose is not None:
        first_img = control(image=pose, width=int(width), height=int(height), **kwargs).images[0]
    elif img2img and ref is not None and img2img_strength > 0:
        # img2img takes its output size from the input image, so width/height are not passed here
        first_img = img2img(image=ref, strength=float(img2img_strength), **kwargs).images[0]
    else:
        first_img = base(width=int(width), height=int(height), **kwargs).images[0]
    if refiner:
        ref_steps = max(5, int(steps * float(refiner_fraction)))
        with torch.inference_mode():
            out = refiner(
                prompt=prompt,
                negative_prompt=negative or None,
                image=first_img,
                num_inference_steps=ref_steps,
                guidance_scale=float(cfg),
                generator=gen,
            ).images[0]
    else:
        out = first_img
    free_mem()
    return out, seed
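# Reference recipe (assumption; this app refines the decoded image above instead): the
# documented SDXL two-stage handoff passes latents from base to refiner with matching
# denoising fractions:
#   latents = base(prompt=prompt, denoising_end=0.8, output_type="latent", ...).images
#   image = refiner(prompt=prompt, image=latents, denoising_start=0.8, ...).images[0]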
# =========================
# 7) Hunyuan inference (Plus)
# =========================
def infer_hunyuan(prompt, negative, seed, rnd, width, height, cfg, steps):
    if rnd: seed = random.randint(0, MAX_SEED)
    gen = torch.Generator(device=device).manual_seed(int(seed))
    hy = get_hunyuan()
    img = hy(
        prompt=prompt,
        negative_prompt=negative or None,
        width=int(width), height=int(height),
        guidance_scale=float(cfg),
        num_inference_steps=int(steps),
        generator=gen,
    ).images[0]
    free_mem()
    return img, seed
# =========================
# 8) Engine selection wrapper
# =========================
def infer_plus(
    engine,
    prompt, negative, seed, rnd, width, height, cfg, steps,
    ref_face, ip_scale, use_img2img, img2img_strength, pose_img,
    use_refiner, refiner_fraction
):
    if engine == "HunyuanImage 3.0":
        return infer_hunyuan(prompt, negative, seed, rnd, width, height, cfg, steps)
    else:
        return infer_sdxl(
            prompt, negative, seed, rnd, width, height, cfg, steps,
            ref_face, ip_scale, use_img2img, img2img_strength, pose_img,
            use_refiner, refiner_fraction
        )
# =========================
# 9) UI
# =========================
examples = [
    "professional portrait of a young adult with short wavy brown hair, freckles, round glasses, natural light, photo-realistic",
    "same person, sitting inside a modern car, shallow depth of field, photo-realistic",
]

with gr.Blocks(css="#col{max-width:780px;margin:0 auto;}") as demo:
    gr.Markdown("## SDXL 2-Stage + IP-Adapter + ControlNet (optional) + HunyuanImage 3.0 (Plus)")
    engine = gr.Dropdown(
        choices=["SDXL (2-stage)", "HunyuanImage 3.0"],
        value="SDXL (2-stage)",
        label="Engine",
    )
    with gr.Column(elem_id="col"):
        prompt = gr.Textbox(label="Prompt", lines=2, value=examples[0])
        negative = gr.Textbox(label="Negative", lines=2, value="low quality, bad anatomy, text, watermark")
        with gr.Row():
            ref_face = gr.Image(type="pil", label="Reference (face) - SDXL")
            pose = gr.Image(type="pil", label="Pose (OpenPose) - SDXL")
        with gr.Accordion("Advanced", open=False):
            seed = gr.Slider(0, MAX_SEED, value=42, step=1, label="Seed")
            rnd = gr.Checkbox(value=False, label="Randomize seed")
            width = gr.Slider(512, MAX_IMAGE, value=DEFAULT_W, step=64, label="Width")
            height = gr.Slider(512, MAX_IMAGE, value=DEFAULT_H, step=64, label="Height")
            cfg = gr.Slider(1.0, 12.0, value=DEFAULT_CFG, step=0.5, label="CFG")
            steps = gr.Slider(10, 50, value=DEFAULT_STEPS, step=1, label="Steps")
            ip_scale = gr.Slider(0.0, 1.2, value=0.8, step=0.05, label="IP-Adapter scale (SDXL)")
            use_img2img = gr.Checkbox(value=True, label="Use img2img with the reference (SDXL)")
            img2img_strength = gr.Slider(0.1, 0.8, value=0.35, step=0.05, label="Img2Img strength (SDXL)")
            use_refiner = gr.Checkbox(value=True, label="Use Refiner (SDXL)")
            refiner_fraction = gr.Slider(0.1, 0.6, value=0.25, step=0.05, label="Refiner step fraction (SDXL)")
        run = gr.Button("Generate", variant="primary")
        out = gr.Image(label="Result")
        gr.Examples(examples=examples, inputs=[prompt])
    run.click(
        fn=infer_plus,
        inputs=[engine, prompt, negative, seed, rnd, width, height, cfg, steps,
                ref_face, ip_scale, use_img2img, img2img_strength, pose,
                use_refiner, refiner_fraction],
        outputs=[out, seed],
    )
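# Optional (assumption, not in the original flow): long SDXL runs on Spaces usually benefit
# from request queuing, e.g. demo.queue(max_size=8) before demo.launch().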
if __name__ == "__main__":
    demo.launch()