# app.py
# THE GUARANTEED WORKING APPLICATION CODE
"""Gradio demo that re-textures a garment in a photo with a fabric pattern.

Pipeline: a Segment Anything (SAM) mask is predicted from three fixed point
prompts on the person image, the fabric is tiled over the frame, shaded with
the person's own luminance, and pasted back through the mask.
"""

import os
import warnings
from io import BytesIO
from typing import Optional

import gradio as gr
import numpy as np
import requests
import torch
from PIL import Image, ImageChops, ImageFilter, ImageOps

warnings.filterwarnings("ignore", category=UserWarning)

# Imported after the warnings filter so UserWarnings raised while importing
# the package stay suppressed, as in the original import order.
from segment_anything import SamPredictor, sam_model_registry  # noqa: E402

# ==================================================================================
# Step 1: Application Setup & Loading the HIGH-QUALITY AI Model
# ==================================================================================
print("⏳ Initializing The Final Quality Edition...")

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
TARGET_SIZE = (512, 512)
SAM_MODEL_TYPE = "vit_h"
SAM_CHECKPOINT_PATH = "sam_vit_h_4b8939.pth"
SAM_DOWNLOAD_URL = (
    "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth"
)
_DOWNLOAD_CHUNK_SIZE = 1 << 20  # 1 MiB per write


def _download_checkpoint(url: str, dest: str) -> None:
    """Stream *url* to *dest*, publishing the file only when complete.

    The payload is written to ``<dest>.part`` and moved into place with
    ``os.replace`` after the last chunk. An interrupted download therefore
    never leaves a truncated file at *dest* that a later ``os.path.exists``
    check would mistake for a valid checkpoint.

    Raises:
        requests.RequestException: on connection failure or HTTP error status.
        OSError: if the file cannot be written or moved.
    """
    partial_path = f"{dest}.part"
    try:
        with requests.get(url, stream=True, timeout=120) as response:
            response.raise_for_status()
            with open(partial_path, "wb") as f:
                for chunk in response.iter_content(
                    chunk_size=_DOWNLOAD_CHUNK_SIZE
                ):
                    f.write(chunk)
        os.replace(partial_path, dest)
    except BaseException:
        # Clean up on any failure (including Ctrl-C), then propagate.
        if os.path.exists(partial_path):
            os.remove(partial_path)
        raise


if not os.path.exists(SAM_CHECKPOINT_PATH):
    print("Downloading HIGH-QUALITY Segment Anything Model...")
    _download_checkpoint(SAM_DOWNLOAD_URL, SAM_CHECKPOINT_PATH)

try:
    print("⏳ Loading SAM model...")
    sam = sam_model_registry[SAM_MODEL_TYPE](
        checkpoint=SAM_CHECKPOINT_PATH
    ).to(device=DEVICE)
    sam_predictor = SamPredictor(sam)
    print("✅ High-Quality SAM model loaded.")
except Exception as e:
    # Keep the original exception type but chain the cause for debugging.
    raise gr.Error(f"Fatal: Could not load SAM model. Error: {e}") from e


# ==================================================================================
# Step 2: Core Functions
# ==================================================================================
def generate_precise_mask(
    image: Image.Image, progress: gr.Progress
) -> Image.Image:
    """Predict a single SAM mask for *image* and return it as a soft 'L' image.

    Three point prompts are placed relative to the image size: two foreground
    points at (40%, 45%) and (60%, 45%) of width/height, and one background
    point at (50%, 25%).
    NOTE(review): presumably these target the torso and the head of a centred
    portrait — confirm against the expected framing.

    Args:
        image: Person image; converted to RGB before being handed to SAM.
        progress: Gradio progress callback, updated to 0.3.

    Returns:
        A mode-'L' mask (255 inside, 0 outside) blurred with a Gaussian of
        radius 1 to soften the edge.
    """
    progress(0.3, desc="🤖 Generating high-quality mask...")

    # The shape unpacking below needs exactly three channels.
    image_np = np.array(image.convert("RGB"))
    sam_predictor.set_image(image_np)

    h, w = image_np.shape[:2]
    input_points = np.array(
        [[w * 0.40, h * 0.45], [w * 0.60, h * 0.45], [w * 0.5, h * 0.25]]
    )
    input_labels = np.array([1, 1, 0])  # 1 = foreground, 0 = background

    masks, _, _ = sam_predictor.predict(
        point_coords=input_points,
        point_labels=input_labels,
        multimask_output=False,
    )

    # Explicit 0/255 uint8 conversion instead of relying on bool-array support.
    mask_u8 = masks[0].astype(np.uint8) * 255
    return Image.fromarray(mask_u8).convert("L").filter(
        ImageFilter.GaussianBlur(1)
    )


def create_perfect_result(fabric_orig, person_base, mask, scale_factor=1.0):
    """Tile the fabric over the person image and composite it through *mask*.

    Args:
        fabric_orig: Fabric pattern image used as the repeating tile.
        person_base: Person image; defines the output size.
        mask: Paste mask of the same size as *person_base*.
        scale_factor: Multiplier on the base tile width, which is a quarter
            of the person image width.

    Returns:
        A copy of *person_base* with the shaded fabric pasted where *mask*
        is set.
    """
    base_tile_width = person_base.width // 4
    tile_w = max(1, int(base_tile_width * scale_factor))

    fabric_w, fabric_h = fabric_orig.size
    # Preserve the fabric's aspect ratio; guard against a zero-width source.
    scaled_h = int(fabric_h * (tile_w / fabric_w)) if fabric_w > 0 else 0
    tile_h = max(1, scaled_h)

    tile = fabric_orig.resize((tile_w, tile_h), Image.Resampling.LANCZOS)
    tiled = Image.new("RGB", person_base.size)
    for x in range(0, person_base.width, tile_w):
        for y in range(0, person_base.height, tile_h):
            tiled.paste(tile, (x, y))

    # The person's contrast-stretched luminance acts as the lighting map.
    light_map = ImageOps.grayscale(person_base).convert("RGB")
    light_map = ImageOps.autocontrast(light_map, cutoff=2)
    shaded = ImageChops.soft_light(tiled, light_map)

    final = person_base.copy()
    final.paste(shaded, (0, 0), mask=mask)
    return final


def load_image_from_url(url) -> Optional[Image.Image]:
    """Fetch *url* and decode it as an RGB image; return ``None`` on failure.

    Best-effort by design: a missing/blank URL, a network or HTTP error, or
    undecodable image data all yield ``None`` so the caller can report a
    missing image.
    """
    if not url or not url.strip():
        return None
    try:
        response = requests.get(url.strip(), timeout=10)
        response.raise_for_status()
        return Image.open(BytesIO(response.content)).convert("RGB")
    except (
        requests.RequestException,
        OSError,  # includes PIL.UnidentifiedImageError
        ValueError,
        Image.DecompressionBombError,
    ):
        return None


def generate_automatic_tryon(
    p_img_upload,
    p_img_url,
    f_img_upload,
    f_img_url,
    progress=gr.Progress(track_tqdm=True),
):
    """Gradio handler: build three try-on renders at different fabric scales.

    An uploaded image takes precedence over its URL counterpart.

    Returns:
        ``(results, mask, mask)`` — the list of rendered images and the mask
        twice, because the click handler lists ``mask_display`` twice in its
        outputs.

    Raises:
        gr.Error: if either the person or the fabric image is unavailable.
    """
    progress(0.05, desc="Loading images...")
    person_img = (
        p_img_upload
        if p_img_upload is not None
        else load_image_from_url(p_img_url)
    )
    fabric_img = (
        f_img_upload
        if f_img_upload is not None
        else load_image_from_url(f_img_url)
    )
    if person_img is None or fabric_img is None:
        raise gr.Error("Missing an image.")

    person_resized = person_img.resize(TARGET_SIZE, Image.Resampling.LANCZOS)
    mask = generate_precise_mask(person_resized, progress)

    progress(0.8, desc="🎨 Applying fabric and lighting...")
    results = [
        create_perfect_result(fabric_img, person_resized, mask, sf)
        for sf in [0.75, 0.4, 1.2]
    ]
    progress(1.0, desc="✅ Done!")
    return results, mask, mask


# ==================================================================================
# Step 3: Gradio User Interface
# ==================================================================================
with gr.Blocks(
    theme=gr.themes.Soft(), title="Virtual Try-On: Final Quality Edition"
) as demo:
    gr.Markdown("# 👔 Virtual Try-On: The Final Quality Edition")
    with gr.Row():
        with gr.Column(scale=2):
            p_upload = gr.Image(type="pil", label="Person in Suit")
            p_url = gr.Textbox(label="Person URL")
            f_upload = gr.Image(type="pil", label="Fabric Pattern")
            f_url = gr.Textbox(label="Fabric URL")
            btn = gr.Button("Generate Perfect Result", variant="primary")
        with gr.Column(scale=3):
            gallery = gr.Gallery(columns=3, object_fit="cover", height=512)
            mask_display = gr.Image(label="The Final, Precise Mask Used")
    btn.click(
        fn=generate_automatic_tryon,
        inputs=[p_upload, p_url, f_upload, f_url],
        # mask_display appears twice to match the handler's 3-tuple return.
        outputs=[gallery, mask_display, mask_display],
    )

if __name__ == "__main__":
    demo.launch()