# Author: Amir Cohen
# feat: add optional lora_id/lora_scale support for per-request LoRA loading
# Commit: a575c6c
import os

# Must be set before huggingface_hub performs any download so the fast
# hf_transfer backend is used when fetching the (large) model weights.
os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"

import uuid
import spaces
import torch
from diffusers import ErnieImagePipeline
from gradio import Server
from gradio.data_classes import FileData

# Relax fp32 matmul precision (e.g. TF32 on supported GPUs) for faster inference.
torch.set_float32_matmul_precision("high")

# Initialize the pipeline once at process startup; it is shared by all requests.
print("Loading model Baidu/ERNIE-Image-Turbo... this may take a few minutes!", flush=True)
try:
    pipe = ErnieImagePipeline.from_pretrained(
        "Baidu/ERNIE-Image-Turbo",
        torch_dtype=torch.bfloat16,  # bf16: halves memory, keeps fp32-like dynamic range
    )
    print("Model loaded successfully. Moving to CUDA...", flush=True)
    pipe = pipe.to("cuda")
    print("Model is on CUDA. Initializing Server...", flush=True)
except Exception as e:
    # Log the failure for the Space logs, then re-raise so startup aborts loudly.
    print(f"Error during model loading: {e}", flush=True)
    raise

app = Server()
@app.api()
@spaces.GPU(duration=60)
def generate_image(
    prompt: str,
    width: int = 1024,
    height: int = 1024,
    guidance_scale: float = 1.0,
    num_inference_steps: int = 8,
    use_prompt_enhancer: bool = True,
    lora_id: str | None = None,
    lora_scale: float = 1.0,
) -> FileData:
    """Generate an image using ERNIE-Image-Turbo.

    Args:
        prompt: Text description of the image to generate. Works best with detailed,
            scene-style descriptions. Excels at text rendering, posters, infographics,
            and complex multi-object compositions.
        width: Image width in pixels. Recommended values: 1024, 848, 1264, 768, 896, 1376, 1200. Default 1024.
        height: Image height in pixels. Recommended values: 1024, 1264, 848, 1376, 1200, 768, 896. Default 1024.
            Use width=1264,height=848 for landscape or width=848,height=1264 for portrait.
        guidance_scale: How closely to follow the prompt. Recommended: 1.0. Range 1.0-7.0.
        num_inference_steps: Denoising steps. More = higher quality but slower. Range 4-30. Default 8.
        use_prompt_enhancer: Enable the built-in prompt enhancer for richer outputs. Default True.
        lora_id: HuggingFace repo ID of a LoRA to apply (e.g. "owner/my-lora"). Optional.
        lora_scale: LoRA influence weight. Recommended 0.7-1.0. Default 1.0.

    Returns:
        Generated image file.

    Raises:
        ValueError: If ``lora_id`` is given but the repo contains no .safetensors file.
    """
    print(f"Endpoint triggered! Prompt: {prompt}, width: {width}, height: {height}, use_pe: {use_prompt_enhancer}, lora_id: {lora_id}", flush=True)

    def _apply_lora_delta(lora_state: dict, scale: float) -> int:
        """Add ``scale * (B @ A)`` into each matching transformer weight in place.

        Called with a positive scale to merge the LoRA and with the negated
        scale to unmerge it (negation is exact in floating point, so the
        subtraction matches the original addition bit-for-bit).
        Returns the number of layers touched.
        """
        params = dict(pipe.transformer.named_parameters())
        applied = 0
        for key in lora_state:
            if "lora_A" not in key:
                continue
            b_key = key.replace("lora_A", "lora_B")
            if b_key not in lora_state:
                continue
            # Map the LoRA key onto the base parameter name, stripping the
            # "diffusion_model." prefix some checkpoints carry, e.g.
            # "diffusion_model.blocks.0.q.lora_A.weight" -> "blocks.0.q.weight".
            param_key = key.replace("lora_A.weight", "weight")
            param_key = param_key.removeprefix("diffusion_model.")
            param_key = param_key.replace(".lora_A", "")
            if param_key not in params:
                continue
            target = params[param_key]
            lora_A = lora_state[key].to(device=target.device, dtype=target.dtype)
            lora_B = lora_state[b_key].to(device=target.device, dtype=target.dtype)
            target.data += (lora_B @ lora_A) * scale
            applied += 1
        return applied

    lora_state = None
    if lora_id:
        from huggingface_hub import hf_hub_download, list_repo_files
        from safetensors.torch import load_file

        # Find a safetensors file in the repo (first match wins).
        repo_files = [f for f in list_repo_files(lora_id) if f.endswith(".safetensors")]
        if not repo_files:
            raise ValueError(f"No .safetensors found in {lora_id}")
        path = hf_hub_download(lora_id, repo_files[0])
        lora_state = load_file(path)
        applied = _apply_lora_delta(lora_state, lora_scale)
        print(f"LoRA applied: {applied} layers merged", flush=True)

    try:
        image = pipe(
            prompt=prompt,
            height=height,
            width=width,
            num_inference_steps=num_inference_steps,
            guidance_scale=guidance_scale,
            use_pe=use_prompt_enhancer,
        ).images[0]
    finally:
        # Unmerge even if inference raised, so the shared pipeline is never
        # left polluted with this request's LoRA deltas.
        if lora_state is not None:
            _apply_lora_delta(lora_state, -lora_scale)

    # Save to a temporary unique file.
    os.makedirs("/tmp/ernie_outputs", exist_ok=True)
    out_path = f"/tmp/ernie_outputs/{uuid.uuid4()}.png"
    image.save(out_path)
    return FileData(path=out_path)
from fastapi.responses import HTMLResponse


@app.get("/")
async def homepage():
    """Serve the custom frontend page (index.html located next to this script)."""
    base_dir = os.path.dirname(os.path.abspath(__file__))
    html_path = os.path.join(base_dir, "index.html")
    with open(html_path, encoding="utf-8") as fh:
        markup = fh.read()
    return HTMLResponse(content=markup)
# Start the server; mcp_server=True additionally exposes the API endpoints as MCP tools.
app.launch(show_error=True, mcp_server=True)