# app.py — vcollos's FLUX LoRA Space
#
# fix(runtime): align flux lora deps with working reference
#
# alinha diffusers, transformers, accelerate e peft com o setup do space
# multimodalart/flux-lora-the-explorer; adiciona log das versoes carregadas e
# do flag USE_PEFT_BACKEND.
# impacto: deve habilitar o backend peft usado por load_lora_weights no runtime.
# commit: 4961367
import spaces
import gradio as gr
import torch
from PIL import Image
from diffusers import DiffusionPipeline, AutoencoderTiny, AutoencoderKL
import random
import os
import json
import io
import time
import importlib
from datetime import datetime
from huggingface_hub import HfFileSystem, ModelCard
from gradio_client import utils as gradio_client_utils
# Logging configuration: timestamped INFO-level records for the whole app.
import logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
# Module-level logger used by every function below.
logger = logging.getLogger(__name__)
def patch_gradio_api_schema():
    """Work around gradio_client crashing on boolean JSON schemas.

    JSON Schema allows a bare ``true``/``false`` as a schema; the stock
    ``_json_schema_to_python_type`` does not handle that case, so it is
    wrapped once to translate booleans to "Any"/"None".  A sentinel
    attribute on the module guards against double patching.
    """
    already_patched = getattr(gradio_client_utils, "_codex_bool_schema_patch", False)
    if already_patched:
        return

    unpatched = gradio_client_utils._json_schema_to_python_type

    def schema_to_type_with_bool_support(schema, defs):
        # A boolean schema means "anything" (true) or "nothing" (false).
        if isinstance(schema, bool):
            return "Any" if schema else "None"
        return unpatched(schema, defs)

    gradio_client_utils._json_schema_to_python_type = schema_to_type_with_bool_support
    gradio_client_utils._codex_bool_schema_patch = True


patch_gradio_api_schema()
def log_runtime_versions():
    """Log versions of the ML stack plus the diffusers PEFT-backend flag."""
    report = {}
    for name in ("diffusers", "transformers", "accelerate", "peft"):
        try:
            module = importlib.import_module(name)
        except Exception as exc:
            report[name] = f"missing ({exc})"
        else:
            report[name] = getattr(module, "__version__", "unknown")
    # USE_PEFT_BACKEND must be truthy for load_lora_weights to use PEFT.
    try:
        from diffusers.utils.constants import USE_PEFT_BACKEND
    except Exception as exc:
        report["USE_PEFT_BACKEND"] = f"unknown ({exc})"
    else:
        report["USE_PEFT_BACKEND"] = USE_PEFT_BACKEND
    logger.info("Runtime packages: %s", report)


log_runtime_versions()
# Supabase (optional): persists generated images and their metadata.
try:
    from supabase import create_client, Client
    # Both env vars must be present; otherwise the client stays disabled.
    url: str = os.getenv('SUPABASE_URL')
    key: str = os.getenv('SUPABASE_KEY')
    supabase: Client = create_client(url, key) if url and key else None
    supabase_enabled = True if supabase else False
    logger.info("Supabase inicializado" if supabase_enabled else "Supabase não configurado")
except Exception as e:
    # Best-effort: the app still works without persistence.
    logger.warning(f"Erro ao inicializar Supabase: {e}")
    supabase_enabled = False
    supabase = None
# Hugging Face token, read once at startup (re-validated on each request).
hf_token = os.getenv("HF_TOKEN")


def require_hf_token():
    """Return a cleaned HF access token or raise a user-facing gr.Error.

    Falls back to HUGGING_FACE_HUB_TOKEN and strips stray whitespace and
    quote characters that sometimes sneak into Space secrets.
    """
    raw = hf_token or os.getenv("HUGGING_FACE_HUB_TOKEN") or ""
    cleaned = raw.strip().strip('"').strip("'")
    if cleaned:
        return cleaned
    raise gr.Error(
        "Defina o secret HF_TOKEN com acesso ao modelo base black-forest-labs/FLUX.1-dev "
        "e aceite a licença na conta usada pelo Space."
    )
# Largest value accepted for a 32-bit RNG seed.
MAX_SEED = 2**32 - 1
# LoRA catalog shown in the gallery; users can append entries via
# add_custom_lora(). Only one entry is enabled right now.
loras = [
    {
        "image": "https://huggingface.co/front/assets/huggingface_logo-noborder.svg",
        "title": "Paula",
        "repo": "vcollos/pp2",
        "weights": "lora.safetensors",
        "trigger_word": "Paula"
    }
]
# Base diffusion model every LoRA is applied on top of.
base_model = "black-forest-labs/FLUX.1-dev"
logger.info(f"Inicializando modelo base: {base_model}")
class TimeMeasure:
    """Context manager that logs the wall-clock duration of its block.

    Usage::

        with TimeMeasure("loading"):
            ...  # logs "🕒 loading: X.XX segundos" on exit
    """

    def __init__(self, name=""):
        # Label included in the emitted log line.
        self.name = name

    def __enter__(self):
        self.start = time.time()
        return self

    def __exit__(self, *args):
        elapsed = time.time() - self.start
        # Kept as an attribute so callers may inspect it after the block.
        self.duration = elapsed
        logger.info(f"🕒 {self.name}: {elapsed:.2f} segundos")
# Upload a generated image to Supabase storage (if configured).
def upload_image_to_supabase(image, filename):
    """Upload *image* to the Supabase "images" bucket and return its URL.

    Args:
        image: PIL image to store (encoded as PNG).
        filename: target object name inside the bucket.

    Returns:
        The public URL of the uploaded object, or None when Supabase is
        disabled or the upload fails (persistence is best-effort).
    """
    if not supabase_enabled:
        return None
    img_bytes = io.BytesIO()
    image.save(img_bytes, format="PNG")
    img_bytes.seek(0)
    # BUG FIX: the path previously contained a literal placeholder instead of
    # the filename, so every upload collided on the same object key.
    storage_path = f"images/{filename}"
    try:
        supabase.storage.from_("images").upload(storage_path, img_bytes.getvalue(), {"content-type": "image/png"})
        base_url = f"{url}/storage/v1/object/public/images"
        return f"{base_url}/{storage_path}"
    except Exception as e:
        logger.error(f"Erro no upload da imagem: {e}")
        return None
# Gallery-click handler: refresh the prompt placeholder and info text.
def update_selection(evt: gr.SelectData):
    """Return (prompt update, markdown text, selected index) for a click."""
    chosen = loras[evt.index]
    repo_id = chosen["repo"]
    placeholder = f"Digite o prompt para {chosen['title']}, de preferência em inglês."
    info_markdown = f"### Selecionado: [{repo_id}](https://huggingface.co/{repo_id}) ✅"
    return gr.update(placeholder=placeholder), info_markdown, evt.index
# Register a user-supplied LoRA repository in the gallery.
def add_custom_lora(custom_lora):
    """Validate a Hugging Face FLUX LoRA repo and add it to `loras`.

    Args:
        custom_lora: repo id ("user/name") or a full huggingface.co URL.

    Returns:
        Tuple of Gradio updates: (card HTML, remove-button visibility,
        gallery, info text, selected index or None).
    """
    global loras
    if not custom_lora:
        return gr.update(visible=False), gr.update(visible=False), gr.update(), "", None
    try:
        # Accept either a bare repo id or a full URL.
        model_id = custom_lora
        if model_id.startswith("https://huggingface.co/"):
            model_id = model_id.replace("https://huggingface.co/", "")
        logger.info(f"Verificando modelo: {model_id}")

        fs = HfFileSystem()

        # Read the model card to confirm this is a FLUX LoRA and to collect
        # the title, trigger word and an example image.
        try:
            model_card = ModelCard.load(model_id)
            # Local name so the module-level `base_model` constant is not
            # shadowed.
            card_base_model = model_card.data.get("base_model")
            if card_base_model != "black-forest-labs/FLUX.1-dev" and card_base_model != "black-forest-labs/FLUX.1-schnell":
                raise gr.Error("Este modelo não é um LoRA do FLUX")
            title = model_id.split("/")[-1]
            trigger_word = model_card.data.get("instance_prompt", "")
            # Example image advertised on the card, if any.
            card_image = model_card.data.get("widget", [{}])[0].get("output", {}).get("url", "")
            image_url = f"https://huggingface.co/{model_id}/resolve/main/{card_image}" if card_image else None
        except gr.Error:
            # BUG FIX: the validation error used to be swallowed by the
            # generic handler below, letting non-FLUX models slip through.
            raise
        except Exception as e:
            # Card missing or unreadable: fall back to repo-derived metadata.
            logger.warning(f"Erro ao carregar card: {e}, tentando método alternativo")
            title = model_id.split("/")[-1]
            trigger_word = ""
            image_url = None

        # Locate the .safetensors weights and, if still needed, a preview image.
        weight_file = None
        try:
            files = fs.ls(model_id, detail=False)
            for file in files:
                filename = file.split("/")[-1]
                if filename.endswith(".safetensors"):
                    weight_file = filename
                if not image_url and filename.lower().endswith((".jpg", ".jpeg", ".png", ".webp")):
                    # BUG FIX: the URL previously contained a literal
                    # placeholder instead of the actual file name.
                    image_url = f"https://huggingface.co/{model_id}/resolve/main/{filename}"
        except Exception as e:
            logger.error(f"Erro ao listar arquivos: {e}")
            raise gr.Error(f"Não foi possível acessar o repositório: {str(e)}")
        if not weight_file:
            raise gr.Error("Nenhum arquivo .safetensors encontrado no repositório")
        # Fall back to a placeholder image when none was found.
        if not image_url:
            image_url = "https://huggingface.co/front/assets/huggingface_logo-noborder.svg"

        # HTML card shown to the user on success.
        card = f'''
<div class="custom_lora_card">
<span>LoRA carregado com sucesso:</span>
<div class="card_internal">
<img src="{image_url}" style="max-width: 100px; max-height: 100px;"/>
<div>
<h3>{title}</h3>
<small>{"Usando: <code><b>"+trigger_word+"</code></b> como palavra-chave" if trigger_word else "Não encontramos a palavra-chave, se tiver, coloque-a no prompt."}<br></small>
</div>
</div>
</div>
'''

        # Reuse the existing entry when this repo is already registered.
        existing_item_index = next((index for (index, item) in enumerate(loras) if item['repo'] == model_id), None)
        if existing_item_index is None:
            new_item = {
                "image": image_url,
                "title": title,
                "repo": model_id,
                "weights": weight_file,
                "trigger_word": trigger_word,
            }
            existing_item_index = len(loras)
            loras.append(new_item)
            logger.info(f"Adicionado novo modelo: {title}")
        return (
            gr.update(visible=True, value=card),
            gr.update(visible=True),
            gr.Gallery(value=[(item["image"], item["title"]) for item in loras]),
            f"Modelo: {title}",
            existing_item_index,
        )
    except Exception as e:
        # Any failure is rendered as an error card; the gallery is untouched.
        logger.error(f"Erro ao adicionar modelo: {e}")
        error_msg = f"Modelo inválido: {str(e)}"
        return (
            gr.update(visible=True, value=error_msg),
            gr.update(visible=False),
            gr.update(),
            "",
            None,
        )
def remove_custom_lora():
    # Reset the custom-LoRA UI: hide the card and the remove button, leave
    # the gallery untouched, clear the info text and the selection index.
    return gr.update(visible=False), gr.update(visible=False), gr.update(), "", None
# Main image-generation entry point; runs on a ZeroGPU worker (60 s budget).
@spaces.GPU(duration=60)
def run_lora(prompt, cfg_scale, steps, selected_index, randomize_seed, seed, width, height, lora_scale, progress=gr.Progress(track_tqdm=True)):
    """Generate one image with the selected LoRA applied on FLUX.1-dev.

    Args:
        prompt: user prompt (trigger word and quality tag are appended).
        cfg_scale: guidance scale forwarded to the pipeline.
        steps: number of inference steps.
        selected_index: index into `loras`; None falls back to 0.
        randomize_seed: when True, `seed` is replaced by a random value.
        seed: RNG seed; echoed back so the UI can display it.
        width, height: output resolution in pixels.
        lora_scale: adapter weight applied to the LoRA.
        progress: Gradio progress tracker (tqdm-linked).

    Returns:
        (image, seed): the generated PIL image and the seed actually used.

    Raises:
        gr.Error: on any failure, with a user-readable message.
    """
    try:
        if selected_index is None:
            selected_index = 0
        token = require_hf_token()
        start_time = time.time()
        logger.info(f"Iniciando geração com modelo: {loras[selected_index]['title']}")

        # The pipeline is rebuilt on every request because ZeroGPU workers
        # are stateless between calls.
        with TimeMeasure("Carregando modelo base"):
            # FIX: use bfloat16 (as the original comment claimed) — FLUX.1-dev
            # is distributed in bfloat16 and fp16 can yield degraded output.
            pipe = DiffusionPipeline.from_pretrained(
                base_model,
                torch_dtype=torch.bfloat16,
                use_safetensors=True,
                token=token,
            )
            pipe.to("cuda")

        selected_lora = loras[selected_index]
        lora_path = selected_lora["repo"]
        lora_weights = selected_lora.get("weights")
        trigger_word = selected_lora.get("trigger_word", "")
        qualidade = "<flux.1-dev>"

        # Compose the final prompt: trigger word + user prompt + quality tag.
        if trigger_word:
            prompt_full = f"{trigger_word} {prompt} {qualidade}"
        else:
            prompt_full = f"{prompt} {qualidade}"

        # Pick a fresh seed when requested, then build a CUDA generator so
        # results are reproducible for a given seed.
        if randomize_seed:
            seed = random.randint(0, MAX_SEED)
        generator = torch.Generator(device="cuda").manual_seed(seed)

        # Attach the selected LoRA adapter (requires the PEFT backend).
        with TimeMeasure(f"Carregando LoRA {selected_lora['title']}"):
            try:
                pipe.load_lora_weights(
                    lora_path,
                    weight_name=lora_weights,
                    adapter_name="lora",
                    token=token,
                )
                pipe.set_adapters(["lora"], adapter_weights=[lora_scale])
            except Exception as e:
                logger.error(f"Erro ao carregar LoRA: {e}")
                raise gr.Error(f"Erro ao carregar LoRA: {str(e)}")

        # Run the diffusion process.
        with TimeMeasure("Gerando imagem"):
            result = pipe(
                prompt=prompt_full,
                num_inference_steps=steps,
                guidance_scale=cfg_scale,
                width=width,
                height=height,
                generator=generator
            )
            image = result.images[0]

        # Best-effort persistence: upload the image, then its metadata.
        if supabase_enabled:
            try:
                # NOTE(review): datetime.utcnow() is deprecated on Python
                # 3.12+; consider datetime.now(timezone.utc). Kept as-is to
                # preserve the stored timestamp format.
                filename = f"image_{seed}_{datetime.utcnow().strftime('%Y%m%d%H%M%S')}.png"
                image_url = upload_image_to_supabase(image, filename)
                if image_url:
                    logger.info(f"Imagem salva no Supabase: {image_url}")
                    # Record the generation parameters alongside the URL.
                    supabase.table("images").insert({
                        "prompt": prompt_full,
                        "cfg_scale": cfg_scale,
                        "steps": steps,
                        "seed": seed,
                        "lora_scale": lora_scale,
                        "image_url": image_url,
                        "created_at": datetime.utcnow().isoformat()
                    }).execute()
                    logger.info("Metadados salvos no Supabase")
            except Exception as e:
                logger.error(f"Erro ao salvar no Supabase: {e}")

        elapsed_time = time.time() - start_time
        logger.info(f"Imagem gerada em {elapsed_time:.2f} segundos")

        # Release cached CUDA memory before returning to the queue.
        torch.cuda.empty_cache()
        return image, seed
    except Exception as e:
        logger.error(f"Erro na geração: {e}")
        raise gr.Error(str(e))
# Gradio interface theme: soft gray palette, rounded corners, subtle shadows.
collos = gr.themes.Soft(
    primary_hue="gray",
    secondary_hue="stone",
    neutral_hue="slate",
    radius_size=gr.themes.Size(lg="15px", md="8px", sm="6px", xl="16px", xs="4px", xxl="24px", xxs="2px")
).set(
    body_background_fill='*primary_100',
    embed_radius='*radius_lg',
    shadow_drop='0 1px 2px rgba(0, 0, 0, 0.1)',
    shadow_drop_lg='0 1px 2px rgba(0, 0, 0, 0.1)',
    shadow_inset='0 1px 2px rgba(0, 0, 0, 0.1)',
    shadow_spread='0 1px 2px rgba(0, 0, 0, 0.1)',
    shadow_spread_dark='0 1px 2px rgba(0, 0, 0, 0.1)',
    block_radius='*radius_lg',
    block_shadow='*shadow_drop',
    container_radius='*radius_lg'
)
# Extra CSS for the success card rendered by add_custom_lora().
css = """
.custom_lora_card {
padding: 10px;
background-color: #f5f5f5;
border-radius: 10px;
margin-top: 10px;
}
.card_internal {
display: flex;
align-items: center;
margin-top: 10px;
}
.card_internal img {
margin-right: 15px;
border-radius: 5px;
}
"""
with gr.Blocks(theme=collos, css=css) as app:
    # Company logo header.
    title = gr.HTML(
        """<img src="https://huggingface.co/spaces/vcollos/Uniodonto/resolve/main/logo/logo_collos_3.png" alt="Logo" style="display: block; margin: 0 auto; padding: 5px 0px 20px 0px; width: 200px;" />""",
        elem_id="title",
    )
    # Index into `loras` for the currently selected model (defaults to 0).
    selected_index = gr.State(0)
    with gr.Row():
        with gr.Column(scale=3):
            prompt = gr.Textbox(label="Prompt", lines=2, placeholder="Digite o prompt para Paula, de preferência em inglês.")
        with gr.Column(scale=1):
            generate_button = gr.Button("Gerar Imagem", variant="primary", elem_id="cta")
    with gr.Row():
        with gr.Column():
            selected_info = gr.Markdown("### Selecionado: [vcollos/pp2](https://huggingface.co/vcollos/pp2) ✅")
            # Gallery of available LoRA models built from the `loras` catalog.
            gallery = gr.Gallery(
                label="Modelo Disponível",
                value=[(item["image"], item["title"]) for item in loras],
                allow_preview=False,
                columns=1,
                show_share_button=False
            )
            gr.Markdown("Somente o modelo `vcollos/pp2` está habilitado neste momento.")
        with gr.Column():
            result = gr.Image(label="Imagem Gerada")
            seed_output = gr.Number(label="Seed", precision=0)
    with gr.Row():
        # Advanced generation controls, collapsed by default.
        with gr.Accordion("Configurações Avançadas", open=False):
            with gr.Column():
                with gr.Row():
                    cfg_scale = gr.Slider(label="CFG Scale", minimum=1, maximum=20, step=0.5, value=3.5)
                    steps = gr.Slider(label="Steps", minimum=1, maximum=50, step=1, value=32)
                with gr.Row():
                    width = gr.Slider(label="Largura", minimum=256, maximum=1536, step=64, value=1024)
                    height = gr.Slider(label="Altura", minimum=256, maximum=1536, step=64, value=1024)
                with gr.Row():
                    randomize_seed = gr.Checkbox(True, label="Seed Aleatória")
                    seed = gr.Slider(label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=0, randomize=True)
                    lora_scale = gr.Slider(label="Intensidade do LoRA", minimum=0, maximum=3, step=0.01, value=1.20)
    # Event wiring.
    # Gallery click -> update prompt placeholder, info text, selected index.
    gallery.select(
        update_selection,
        inputs=[],
        outputs=[prompt, selected_info, selected_index]
    )
    generate_inputs = [
        prompt, cfg_scale, steps, selected_index,
        randomize_seed, seed, width, height, lora_scale
    ]
    generate_outputs = [result, seed_output]
    # Both the button and Enter in the prompt box trigger generation.
    generate_button.click(run_lora, inputs=generate_inputs, outputs=generate_outputs)
    prompt.submit(run_lora, inputs=generate_inputs, outputs=generate_outputs)

# Start the app with request queueing (required for ZeroGPU Spaces).
app.queue()
app.launch()