# NOTE: removed non-Python page residue (Spaces header, "Runtime error" banner,
# file size, commit hashes, and a line-number row) captured during extraction.
"""
Minimal Gradio wrapper for the given Qwen-Image-Edit inference script.
Features:
- Loads the model once and reuses it.
- Inputs: image, edit prompt, cond_b, cond_delta, optional model path.
- Matches your original settings (size 1024, steps=24, true_cfg_scale=4.0,
fixed seed=42, and the same GRAG scale structure repeated 60 times).
Run:
pip install gradio pillow torch huggingface_hub requests urllib3
# plus your project deps providing hacked_models/* and model weights
python gradio_qwen_edit_minimal.py
Then open the local URL printed by Gradio.
"""
import os
from typing import Optional
import gradio as gr
import torch
from PIL import Image
from huggingface_hub import snapshot_download
import os
# --- your project imports (as in the original script) ---
from hacked_models.scheduler import FlowMatchEulerDiscreteScheduler
from hacked_models.pipeline import QwenImageEditPipeline
from hacked_models.models import QwenImageTransformer2DModel
from hacked_models.utils import seed_everything
from huggingface_hub import snapshot_download
from requests.exceptions import ChunkedEncodingError
from urllib3.exceptions import ProtocolError
import os, time
def robust_snapshot_download(repo_id, local_dir, token=None, retries=5):
    """Download a Hugging Face Hub snapshot, retrying on transient network errors.

    Args:
        repo_id: Hub repository id, e.g. ``"Qwen/Qwen-Image-Edit"``.
        local_dir: Local directory to download into (created if missing).
        token: Optional HF access token for gated/private repos.
        retries: Maximum number of attempts before giving up.

    Returns:
        The snapshot path returned by ``snapshot_download``.

    Raises:
        RuntimeError: If every attempt fails with a network error.
    """
    os.makedirs(local_dir, exist_ok=True)
    # Opt into the accelerated hf_transfer backend when it is installed.
    os.environ.setdefault("HF_HUB_ENABLE_HF_TRANSFER", "1")
    last_err = None
    for attempt in range(retries):
        try:
            return snapshot_download(
                repo_id=repo_id,
                local_dir=local_dir,
                local_dir_use_symlinks=False,
                resume_download=True,  # resume partially-downloaded files
                token=token,  # `use_auth_token` is deprecated upstream; `token` is current
                max_workers=1,  # single worker: fewer dropped connections on flaky links
            )
        except (ChunkedEncodingError, ProtocolError) as e:
            last_err = e
            # Exponential backoff, capped at 30 seconds.
            wait = min(2 ** attempt, 30)
            print(f"[download] network error {attempt+1}/{retries}: {e}; retry in {wait}s", flush=True)
            time.sleep(wait)
    raise RuntimeError(f"Download failed after {retries} retries: {last_err}")
# -----------------------------
# Global state
# -----------------------------
# Prefer CUDA when available; bfloat16 only when running on GPU.
_DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
_DTYPE = torch.bfloat16 if _DEVICE == "cuda" else torch.float32
# Lazily-populated pipeline cache; filled by _load_pipeline on first use.
_PIPELINE: Optional[QwenImageEditPipeline] = None
_LOADED_MODEL_PATH: Optional[str] = None
def _load_pipeline(model_path: str) -> QwenImageEditPipeline:
    """Return the cached QwenImageEditPipeline for *model_path*, loading it on first use."""
    global _PIPELINE, _LOADED_MODEL_PATH
    # Cache hit: same path already loaded — reuse the existing pipeline.
    cache_valid = _PIPELINE is not None and _LOADED_MODEL_PATH == model_path
    if cache_valid:
        return _PIPELINE
    # Fix all RNGs once before loading (matches the original script).
    seed_everything(42)
    # Load the scheduler and transformer individually, then hand them to the pipeline.
    sched = FlowMatchEulerDiscreteScheduler.from_pretrained(
        os.path.join(model_path, "scheduler"), torch_dtype=_DTYPE
    )
    dit = QwenImageTransformer2DModel.from_pretrained(
        os.path.join(model_path, "transformer"), torch_dtype=_DTYPE
    )
    pipeline = QwenImageEditPipeline.from_pretrained(
        model_path, torch_dtype=_DTYPE, scheduler=sched, transformer=dit
    )
    pipeline.set_progress_bar_config(disable=None)
    # Cast dtype first, then move to the target device.
    pipeline.to(_DTYPE)
    pipeline.to(_DEVICE)
    # Publish to the module-level cache.
    _PIPELINE, _LOADED_MODEL_PATH = pipeline, model_path
    return pipeline
def _build_grag_scale(cond_b: float, cond_delta: float, repeats: int = 60):
"""Replicates your original GRAG schedule structure.
Each element is: ((512, 1.0, 1.0), (4096, cond_b, cond_delta))
"""
return [((512, 1.0, 1.0), (4096, cond_b, cond_delta))] * repeats
def predict(
    image: Image.Image,
    edit_prompt: str,
    cond_b: float,
    cond_delta: float,
):
    """Run one edit pass through the global pipeline; return the first output image.

    Returns None when either the image or the prompt is missing.
    """
    # Guard clause: both an image and a non-empty prompt are required.
    if image is None or not edit_prompt:
        return None
    # Same preprocessing as the original script: force RGB at 1024x1024.
    prepared = image.convert("RGB").resize((1024, 1024))
    call_kwargs = dict(
        image=prepared,
        prompt=edit_prompt,
        generator=torch.manual_seed(42),
        true_cfg_scale=4.0,
        negative_prompt=" ",
        num_inference_steps=24,
        return_dict=False,
        grag_scale=_build_grag_scale(cond_b, cond_delta, repeats=60),
    )
    with torch.inference_mode():
        image_batch, x0_images, saved_outputs = pipe(**call_kwargs)
    # First image matches the original script's save behavior.
    return image_batch[0]
# Local checkout directory and upstream repo for the model weights.
model_dir = "Qwen-Image-Edit"
repo_id = "Qwen/Qwen-Image-Edit"
# Download only when the directory is missing or empty (idempotent on restart).
if not os.path.exists(model_dir) or not os.listdir(model_dir):
    robust_snapshot_download(repo_id, model_dir, token=os.getenv("HF_TOKEN"))
    print(f"Model downloaded to {model_dir}")
else:
    print(f"Model already exists at {model_dir}")
# Load the pipeline eagerly at import time; predict() reads this global.
pipe = _load_pipeline(model_dir)
# --- Gradio UI: built at module level so `demo` is importable ---
with gr.Blocks(title="Qwen Image Edit — Minimal GRAG Demo") as demo:
    gr.Markdown("# Qwen Image Edit — Minimal GRAG Demo\nUpload an image, enter your edit instruction, and set GRAG params.")
    # Input image on the left, edited result on the right.
    with gr.Row():
        in_image = gr.Image(label="Input Image", type="pil")
        out_image = gr.Image(label="Edited Output", type="pil")
    edit_prompt = gr.Textbox(label="Edit Instruction", placeholder="e.g., Put a pair of black-framed glasses on him.")
    # GRAG knobs forwarded to predict() (and from there to _build_grag_scale).
    with gr.Row():
        cond_b = gr.Slider(label="cond_b", minimum=0.8, maximum=2.0, value=1.0, step=0.01)
        cond_delta = gr.Slider(label="cond_delta", minimum=0.8, maximum=2.0, value=1.0, step=0.01)
    run_btn = gr.Button("Run Edit")
    # Wire the button to predict(); api_name also exposes it as an API endpoint.
    run_btn.click(
        fn=predict,
        inputs=[in_image, edit_prompt, cond_b, cond_delta],
        outputs=[out_image],
        api_name="run_edit",
    )
    gr.Markdown(
        """
**Notes**
- Uses fixed seed=42 and num_inference_steps=24 to match your script.
- Resizes the input to 1024×1024 before inference (as in your code).
- `grag_scale` is built as a list of length 60 with the same tuples.
- Automatically chooses CUDA if available; otherwise runs on CPU.
"""
    )
if __name__ == "__main__":
    # queue() enables request queuing; share=True publishes a public Gradio link.
    demo.queue().launch(share=True)