w-1 / app.py
Hana Celeste
Update app.py
e522c32 verified
import gradio as gr
import torch
from diffusers import StableDiffusionInstructPix2PixPipeline, EulerAncestralDiscreteScheduler
from PIL import Image
import gc
import warnings
# Tắt cảnh báo transformers về tokenization spaces
warnings.filterwarnings(
"ignore",
category=FutureWarning,
module="transformers.tokenization_utils_base"
)
# =============================
# CONFIG
# =============================
MODEL_ID = "timbrooks/instruct-pix2pix"
MAX_IMAGE_SIZE = 512 # Có thể giảm xuống 384 nếu vẫn OOM
DEFAULT_STEPS = 20
DEFAULT_GUIDANCE = 7.5
MAX_PROMPT_LENGTH = 200 # Tăng lên để hỗ trợ prompt dài hơn (CLIP mặc định 77, nhưng ta xử lý được dài hơn)
# =============================
# LOAD MODEL (CPU only - NO CUDA/OFFLOAD)
# =============================
print("Loading model...")
try:
pipe = StableDiffusionInstructPix2PixPipeline.from_pretrained(
MODEL_ID,
torch_dtype=torch.float32,
safety_checker=None,
requires_safety_checker=False,
)
pipe.scheduler = EulerAncestralDiscreteScheduler.from_config(pipe.scheduler.config)
pipe.to("cpu")
# Chỉ dùng cái này để tiết kiệm RAM, an toàn trên CPU
pipe.enable_attention_slicing()
# KHÔNG dùng bất kỳ offload nào: enable_model_cpu_offload, sequential_cpu_offload, v.v.
# Nếu bạn thấy dòng nào có "offload" hoặc "cuda" thì xóa ngay
print("Model loaded successfully!")
except Exception as e:
print(f"Model loading failed: {str(e)}")
raise
# =============================
# IMAGE EDIT FUNCTION
# =============================
def edit_image(image, prompt, negative_prompt, steps, guidance_scale):
if image is None:
return None, "Vui lòng upload ảnh trước nhé!"
try:
# Resize ảnh để tránh OOM
image = image.convert("RGB")
image.thumbnail((MAX_IMAGE_SIZE, MAX_IMAGE_SIZE))
print(f"Ảnh đã resize: {image.size}")
# Dọn RAM trước khi generate
gc.collect()
generator = torch.Generator(device="cpu").manual_seed(42)
# Xử lý prompt dài: cắt bớt nếu quá dài (tránh lỗi indexing)
if len(prompt.split()) > MAX_PROMPT_LENGTH:
prompt = " ".join(prompt.split()[:MAX_PROMPT_LENGTH])
status_msg = f"Prompt quá dài, đã cắt còn {MAX_PROMPT_LENGTH} từ."
else:
status_msg = "Đang xử lý..."
result = pipe(
prompt=prompt,
image=image,
negative_prompt=negative_prompt if negative_prompt.strip() else None,
num_inference_steps=int(steps),
guidance_scale=float(guidance_scale),
generator=generator,
).images[0]
return result, status_msg + " Thành công!"
except RuntimeError as e:
if "out of memory" in str(e).lower():
return None, "❌ Hết RAM. Thử:\n• Ảnh nhỏ hơn\n• Giảm steps xuống 10-15\n• Đóng tab khác"
if "CUDA" in str(e):
return None, "Lỗi CUDA - Đã cấu hình CPU only, nếu vẫn lỗi hãy báo mình nhé!"
return None, f"Lỗi runtime: {str(e)}"
except Exception as e:
return None, f"Lỗi: {str(e)}"
# =============================
# GRADIO INTERFACE
# =============================
with gr.Blocks(title="InstructPix2Pix - CPU Edition") as demo:
gr.Markdown(
"""
# 🖌 InstructPix2Pix (CPU version)
Upload ảnh → Viết hướng dẫn chỉnh sửa (prompt dài cũng ok, mình tự cắt nếu cần) → Generate
⚠️ Chạy CPU nên chậm (~1-3 phút mỗi lần). Ảnh nhỏ + steps ít = nhanh hơn.
"""
)
with gr.Row():
input_image = gr.Image(label="Ảnh gốc", type="pil", image_mode="RGB", height=350)
output_image = gr.Image(label="Ảnh sau chỉnh sửa", type="pil", height=350)
prompt = gr.Textbox(
label="Prompt (hướng dẫn chỉnh sửa - có thể dài)",
placeholder="remove all clothing completely, keep face hair body proportions unchanged, high fidelity, realistic skin texture",
lines=5,
max_lines=10
)
negative = gr.Textbox(
label="Negative prompt (tránh những thứ này)",
value="blurry, low quality, deformed, bad anatomy, extra limbs, watermark, text",
lines=2
)
with gr.Accordion("Cài đặt nâng cao", open=False):
steps_slider = gr.Slider(10, 50, value=DEFAULT_STEPS, step=5, label="Số bước inference")
guidance = gr.Slider(1.0, 15.0, value=DEFAULT_GUIDANCE, step=0.5, label="Guidance scale")
btn = gr.Button("✨ Generate", variant="primary")
status = gr.Textbox(label="Trạng thái", interactive=False)
btn.click(
fn=edit_image,
inputs=[input_image, prompt, negative, steps_slider, guidance],
outputs=[output_image, status]
)
# Launch
demo.queue(max_size=3).launch(server_name="0.0.0.0", server_port=7860)