# Hugging Face Spaces deployment — hardware: ZeroGPU ("Running on Zero").
import os

# Enable hf_transfer for faster Hub downloads; must be set before any
# huggingface_hub download machinery runs, hence before the other imports.
os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"

import uuid

import spaces
import torch
from diffusers import ErnieImagePipeline
from gradio import Server
from gradio.data_classes import FileData

# Optimize for performance if on GPU
# (allows TF32 matmuls on Ampere+; harmless on CPU)
torch.set_float32_matmul_precision("high")

# Initialize Pipeline
# NOTE: this runs at import time — the model download / CUDA move happens
# once, when the module is loaded, and `pipe` is shared by all requests.
print("Loading model Baidu/ERNIE-Image-Turbo... this may take a few minutes!", flush=True)
try:
    pipe = ErnieImagePipeline.from_pretrained(
        "Baidu/ERNIE-Image-Turbo",
        torch_dtype=torch.bfloat16,
    )
    print("Model loaded successfully. Moving to CUDA...", flush=True)
    pipe = pipe.to("cuda")
    print("Model is on CUDA. Initializing Server...", flush=True)
except Exception as e:
    # Log then re-raise so the Space shows a startup failure instead of
    # silently running without a model.
    print(f"Error during model loading: {e}", flush=True)
    raise

app = Server()
def _apply_lora_delta(lora_state: dict, scale: float, sign: float) -> int:
    """Merge (sign=+1.0) or unmerge (sign=-1.0) LoRA deltas into ``pipe.transformer``.

    For every ``lora_A``/``lora_B`` pair in *lora_state*, adds
    ``sign * scale * (B @ A)`` in-place to the matching transformer weight.

    Args:
        lora_state: state dict loaded from a LoRA .safetensors file.
        scale: LoRA influence weight.
        sign: +1.0 to merge, -1.0 to unmerge.

    Returns:
        Number of transformer layers whose weights were modified.
    """
    params = dict(pipe.transformer.named_parameters())
    touched = 0
    for key in lora_state:
        if "lora_A" not in key:
            continue
        b_key = key.replace("lora_A", "lora_B")
        if b_key not in lora_state:
            continue
        # Map the LoRA key onto the transformer parameter name:
        # drop the "diffusion_model." prefix and the ".lora_A" infix.
        param_key = key.replace("lora_A.weight", "weight")
        param_key = param_key.removeprefix("diffusion_model.")
        param_key = param_key.replace(".lora_A", "")
        if param_key not in params:
            continue
        target = params[param_key]
        lora_A = lora_state[key].to(device=target.device, dtype=target.dtype)
        lora_B = lora_state[b_key].to(device=target.device, dtype=target.dtype)
        # NOTE(review): merging/unmerging in bfloat16 is not exactly
        # reversible; repeated requests may accumulate small drift in the
        # base weights. Consider keeping an fp32 copy if this matters.
        target.data += (lora_B @ lora_A) * scale * sign
        touched += 1
    return touched


def generate_image(
    prompt: str,
    width: int = 1024,
    height: int = 1024,
    guidance_scale: float = 1.0,
    num_inference_steps: int = 8,
    use_prompt_enhancer: bool = True,
    lora_id: str | None = None,
    lora_scale: float = 1.0,
) -> FileData:
    """Generate an image using ERNIE-Image-Turbo.
    Args:
        prompt: Text description of the image to generate. Works best with detailed,
            scene-style descriptions. Excels at text rendering, posters, infographics,
            and complex multi-object compositions.
        width: Image width in pixels. Recommended values: 1024, 848, 1264, 768, 896, 1376, 1200. Default 1024.
        height: Image height in pixels. Recommended values: 1024, 1264, 848, 1376, 1200, 768, 896. Default 1024.
            Use width=1264,height=848 for landscape or width=848,height=1264 for portrait.
        guidance_scale: How closely to follow the prompt. Recommended: 1.0. Range 1.0-7.0.
        num_inference_steps: Denoising steps. More = higher quality but slower. Range 4-30. Default 8.
        use_prompt_enhancer: Enable the built-in prompt enhancer for richer outputs. Default True.
        lora_id: HuggingFace repo ID of a LoRA to apply (e.g. "owner/my-lora"). Optional.
        lora_scale: LoRA influence weight. Recommended 0.7–1.0. Default 1.0.
    Returns:
        Generated image file.
    Raises:
        ValueError: if *lora_id* is given but the repo contains no .safetensors file.
    """
    print(f"Endpoint triggered! Prompt: {prompt}, width: {width}, height: {height}, use_pe: {use_prompt_enhancer}, lora_id: {lora_id}", flush=True)
    lora_state = None
    if lora_id:
        # Lazy imports: only needed (and only paid for) on LoRA requests.
        from huggingface_hub import hf_hub_download, list_repo_files
        from safetensors.torch import load_file

        # Find the first safetensors file in the repo.
        repo_files = [f for f in list_repo_files(lora_id) if f.endswith(".safetensors")]
        if not repo_files:
            raise ValueError(f"No .safetensors found in {lora_id}")
        path = hf_hub_download(lora_id, repo_files[0])
        lora_state = load_file(path)
        applied = _apply_lora_delta(lora_state, lora_scale, 1.0)
        print(f"LoRA applied: {applied} layers merged", flush=True)
    try:
        image = pipe(
            prompt=prompt,
            height=height,
            width=width,
            num_inference_steps=num_inference_steps,
            guidance_scale=guidance_scale,
            use_pe=use_prompt_enhancer,
        ).images[0]
    finally:
        # Always unmerge, even if generation fails: `pipe` is a shared
        # module-level model, and leaving the LoRA merged would corrupt
        # every subsequent request.
        if lora_state is not None:
            _apply_lora_delta(lora_state, lora_scale, -1.0)
    # Save to a temporary unique file so concurrent requests never collide.
    os.makedirs("/tmp/ernie_outputs", exist_ok=True)
    out_path = f"/tmp/ernie_outputs/{uuid.uuid4()}.png"
    image.save(out_path)
    return FileData(path=out_path)
| from fastapi.responses import HTMLResponse | |
async def homepage():
    """Read index.html from this file's directory and return it as HTML.

    NOTE(review): this coroutine is never registered as a route on ``app``
    anywhere in this file — confirm it is wired up elsewhere.
    """
    app_dir = os.path.dirname(os.path.abspath(__file__))
    index_file = os.path.join(app_dir, "index.html")
    with open(index_file, "r", encoding="utf-8") as handle:
        markup = handle.read()
    return HTMLResponse(content=markup)
| app.launch(show_error=True, mcp_server=True) | |