# Try-on-app / app.py
# Last commit: d2438ed "Update app.py" by Johdw (verified)
# app.py
# THE GUARANTEED WORKING APPLICATION CODE
import gradio as gr
import torch
import numpy as np
from PIL import Image, ImageFilter, ImageOps, ImageChops
import requests
from io import BytesIO
import os
import warnings
warnings.filterwarnings("ignore", category=UserWarning)
# ==================================================================================
# Step 1: Application Setup & Loading the HIGH-QUALITY AI Model
# ==================================================================================
# Start-up banner, printed once at import time.
print("⏳ Initializing The Final Quality Edition...")

# Run SAM on the GPU when torch can see one, otherwise on the CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"

# (width, height) every person image is resized to before masking.
TARGET_SIZE = (512, 512)

# Segment Anything checkpoint settings; the file is fetched from
# SAM_DOWNLOAD_URL when it is not already present locally.
SAM_MODEL_TYPE = "vit_h"
SAM_CHECKPOINT_PATH = "sam_vit_h_4b8939.pth"
SAM_DOWNLOAD_URL = "https://dl.fbaipublicfiles.com/segment_anything/" + SAM_CHECKPOINT_PATH
if not os.path.exists(SAM_CHECKPOINT_PATH):
    # Fetch the checkpoint only when it is not already on disk.
    # Stream into a temporary ".part" file and rename it into place only after
    # the whole body has been written. Otherwise an interrupted download would
    # leave a truncated checkpoint at SAM_CHECKPOINT_PATH. The exists() check
    # above would then skip re-downloading, and model loading would keep failing.
    print("Downloading HIGH-QUALITY Segment Anything Model...")
    _partial_path = SAM_CHECKPOINT_PATH + ".part"
    try:
        # Using the response as a context manager releases the connection
        # even if writing fails part-way.
        with requests.get(SAM_DOWNLOAD_URL, stream=True, timeout=120) as r:
            r.raise_for_status()
            with open(_partial_path, "wb") as f:
                for chunk in r.iter_content(chunk_size=8192):
                    if chunk:  # skip keep-alive chunks
                        f.write(chunk)
        # Atomic rename: the final path is only created once the file is complete.
        os.replace(_partial_path, SAM_CHECKPOINT_PATH)
    finally:
        # Only still present if something above raised; discard the partial file.
        if os.path.exists(_partial_path):
            os.remove(_partial_path)
from segment_anything import sam_model_registry, SamPredictor
# Build the SAM model from the local checkpoint, move it to DEVICE and wrap it
# in a SamPredictor. The module-level `sam_predictor` is what the mask
# generator uses.
try:
    print("⏳ Loading SAM model...")
    sam = sam_model_registry[SAM_MODEL_TYPE](checkpoint=SAM_CHECKPOINT_PATH).to(device=DEVICE)
    sam_predictor = SamPredictor(sam)
    print("βœ… High-Quality SAM model loaded.")
except Exception as e:
    # This runs at import time, before any Gradio UI exists. gr.Error is
    # intended for errors raised from event handlers, so raise a plain
    # RuntimeError instead and keep the original failure as __cause__.
    raise RuntimeError(f"Fatal: Could not load SAM model. Error: {e}") from e
# ==================================================================================
# Step 2: Core Functions
# ==================================================================================
def generate_precise_mask(image: Image.Image, progress: gr.Progress):
    """Predict a garment mask for ``image`` using SAM with fixed point prompts.

    The prompts are two positive clicks at (40%, 45%) and (60%, 45%) of the
    image width/height, plus one negative click at (50%, 25%). The negative
    click is presumably meant to exclude the head/face; confirm against the
    intended inputs.

    Args:
        image: Person image. Any PIL mode is accepted; it is converted to RGB
            before being passed to SAM.
        progress: Gradio progress tracker, updated once at 30%.

    Returns:
        A grayscale ('L') PIL image of the same size as ``image``. Masked
        pixels are 255 and the rest are 0, with a radius-1 Gaussian blur
        applied to soften the edge.
    """
    progress(0.3, desc="πŸ€– Generating high-quality mask...")
    # SamPredictor.set_image expects an HxWx3 uint8 RGB array. Converting first
    # also keeps the shape handling below safe for RGBA or grayscale uploads.
    image_np = np.array(image.convert("RGB"))
    sam_predictor.set_image(image_np)
    h, w = image_np.shape[:2]
    input_points = np.array([[w * 0.40, h * 0.45], [w * 0.60, h * 0.45], [w * 0.5, h * 0.25]])
    input_labels = np.array([1, 1, 0])  # SAM convention: 1 = foreground, 0 = background
    masks, _, _ = sam_predictor.predict(
        point_coords=input_points, point_labels=input_labels, multimask_output=False
    )
    # masks[0] is boolean; scale it to an explicit 0/255 uint8 array so the
    # resulting PIL image is mode 'L'.
    mask_u8 = (masks[0] * 255).astype(np.uint8)
    return Image.fromarray(mask_u8).filter(ImageFilter.GaussianBlur(1))
def create_perfect_result(fabric_orig, person_base, mask, scale_factor=1.0):
    """Composite a tiled fabric texture onto the masked region of a person image.

    Steps:
      1. Scale the fabric so one tile is ``scale_factor`` times a quarter of
         the person image's width, preserving the fabric's aspect ratio.
      2. Tile it across a canvas the size of ``person_base``.
      3. Soft-light blend the canvas with the person's contrast-stretched
         luminance.
      4. Paste the result onto a copy of ``person_base`` through ``mask``.

    Args:
        fabric_orig: Fabric/pattern image (any PIL mode; converted to RGB).
        person_base: Person image to composite onto. It is not modified.
        mask: Grayscale paste mask; brighter pixels receive more fabric. It is
            resized to ``person_base.size`` if the sizes differ.
        scale_factor: Tile width relative to a quarter of the person's width.

    Returns:
        A new PIL image of the same size and mode as ``person_base``.
    """
    fabric = fabric_orig.convert("RGB")

    # At scale_factor == 1.0 one tile spans a quarter of the person's width.
    base_size = person_base.width // 4
    tile_w = max(1, int(base_size * scale_factor))
    fw, fh = fabric.size
    # Preserve the fabric aspect ratio with exact integer arithmetic; guard zero width.
    tile_h = max(1, fh * tile_w // fw) if fw > 0 else 1

    tile = fabric.resize((tile_w, tile_h), Image.Resampling.LANCZOS)
    canvas = Image.new("RGB", person_base.size)
    for x in range(0, person_base.width, tile_w):
        for y in range(0, person_base.height, tile_h):
            canvas.paste(tile, (x, y))

    # Soft-light against the person's stretched luminance so the flat texture
    # takes on the light/dark variation of the original photo.
    light_map = ImageOps.autocontrast(ImageOps.grayscale(person_base).convert("RGB"), cutoff=2)
    shaded = ImageChops.soft_light(canvas, light_map)

    # Image.paste requires the mask to match the pasted region's size.
    if mask.size != person_base.size:
        mask = mask.resize(person_base.size)
    result = person_base.copy()
    result.paste(shaded, (0, 0), mask=mask)
    return result
def load_image_from_url(url):
    """Download ``url`` and decode it as an RGB PIL image.

    This is best-effort by design: the caller treats ``None`` as "no image".
    ``None`` is returned when ``url`` is empty or blank (the default value of
    an untouched Gradio textbox), when the HTTP request fails, or when the
    payload cannot be decoded as an image. Unexpected errors are no longer
    silently swallowed.

    Args:
        url: Image URL, or ``None`` / empty string when not provided.

    Returns:
        The decoded RGB image, or ``None``.
    """
    url = (url or "").strip()
    if not url:
        return None
    try:
        # The context manager closes the response/connection even on error.
        with requests.get(url, timeout=10) as response:
            response.raise_for_status()
            payload = response.content
        return Image.open(BytesIO(payload)).convert("RGB")
    except (requests.RequestException, OSError, ValueError, Image.DecompressionBombError) as exc:
        # OSError covers PIL's UnidentifiedImageError and truncated images.
        print(f"Warning: could not load image from {url!r}: {exc}")
        return None
def generate_automatic_tryon(p_img_upload, p_img_url, f_img_upload, f_img_url, progress=gr.Progress(track_tqdm=True)):
    """Gradio handler: render the fabric onto the person at three tile scales.

    For each image, an upload takes precedence over its URL.

    Args:
        p_img_upload: Uploaded person image, or ``None``.
        p_img_url: Person image URL, used when there is no upload.
        f_img_upload: Uploaded fabric image, or ``None``.
        f_img_url: Fabric image URL, used when there is no upload.
        progress: Gradio progress tracker.

    Returns:
        ``(results, mask, mask)``. ``results`` holds three composites, one
        for each fabric scale factor 0.75, 0.4 and 1.2. The SAM mask is
        repeated because the click handler lists the mask display twice in
        its outputs.

    Raises:
        gr.Error: if either image is missing or could not be loaded. The
            message names which one.
    """
    progress(0.05, desc="Loading images...")
    person_img = p_img_upload if p_img_upload is not None else load_image_from_url(p_img_url)
    fabric_img = f_img_upload if f_img_upload is not None else load_image_from_url(f_img_url)

    missing = [name for name, img in (("person", person_img), ("fabric", fabric_img)) if img is None]
    if missing:
        raise gr.Error(
            f"Missing the {' and '.join(missing)} image: upload one or provide a valid image URL."
        )

    # SAM and the soft-light blend both expect 3-channel RGB input.
    person_resized = person_img.convert("RGB").resize(TARGET_SIZE, Image.Resampling.LANCZOS)
    mask = generate_precise_mask(person_resized, progress)

    progress(0.8, desc="🎨 Applying fabric and lighting...")
    results = [create_perfect_result(fabric_img, person_resized, mask, sf) for sf in (0.75, 0.4, 1.2)]
    progress(1.0, desc="βœ… Done!")
    return results, mask, mask
# ==================================================================================
# Step 3: Gradio User Interface
# ==================================================================================
with gr.Blocks(theme=gr.themes.Soft(), title="Virtual Try-On: Final Quality Edition") as demo:
    gr.Markdown("# πŸ‘” Virtual Try-On: The Final Quality Edition")
    with gr.Row():
        # Left column: inputs. Each image may be uploaded or given by URL;
        # the handler prefers the upload when both are present.
        with gr.Column(scale=2):
            person_upload = gr.Image(type="pil", label="Person in Suit")
            person_url = gr.Textbox(label="Person URL")
            fabric_upload = gr.Image(type="pil", label="Fabric Pattern")
            fabric_url = gr.Textbox(label="Fabric URL")
            run_button = gr.Button("Generate Perfect Result", variant="primary")
        # Right column: the three scaled composites and the SAM mask.
        with gr.Column(scale=3):
            result_gallery = gr.Gallery(columns=3, object_fit="cover", height=512)
            mask_preview = gr.Image(label="The Final, Precise Mask Used")

    handler_inputs = [person_upload, person_url, fabric_upload, fabric_url]
    # generate_automatic_tryon returns (results, mask, mask), so the mask
    # preview appears twice to match the three returned values.
    handler_outputs = [result_gallery, mask_preview, mask_preview]
    run_button.click(generate_automatic_tryon, inputs=handler_inputs, outputs=handler_outputs)

demo.launch()