Spaces:
Sleeping
Sleeping
| # app.py | |
| # THE GUARANTEED WORKING APPLICATION CODE | |
| import gradio as gr | |
| import torch | |
| import numpy as np | |
| from PIL import Image, ImageFilter, ImageOps, ImageChops | |
| import requests | |
| from io import BytesIO | |
| import os | |
| import warnings | |
warnings.filterwarnings("ignore", category=UserWarning)
# ==================================================================================
# Step 1: Application Setup & Loading the HIGH-QUALITY AI Model
# ==================================================================================
print("β³ Initializing The Final Quality Edition...")
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
TARGET_SIZE = (512, 512)  # size every person image is resized to before masking
SAM_MODEL_TYPE = "vit_h"  # key into segment_anything.sam_model_registry
SAM_CHECKPOINT_PATH = "sam_vit_h_4b8939.pth"
SAM_DOWNLOAD_URL = "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth"


def _download_file(url, dest, timeout=120, chunk_size=8192):
    """Stream ``url`` to ``dest`` without ever leaving a partial file at ``dest``.

    The body is written to ``dest + ".part"`` and renamed onto ``dest`` only
    after the whole response has been received. The download below is skipped
    whenever ``os.path.exists(dest)`` is true, so a truncated file at ``dest``
    (crash, timeout or restart mid-download) would otherwise be loaded as if it
    were a complete checkpoint on every later start.

    Raises:
        requests.HTTPError: the server answered with a non-2xx status.
        requests.RequestException / OSError: network or disk failures.
    """
    tmp_path = dest + ".part"
    try:
        with requests.get(url, stream=True, timeout=timeout) as response:
            response.raise_for_status()
            with open(tmp_path, "wb") as out:
                for chunk in response.iter_content(chunk_size=chunk_size):
                    out.write(chunk)
        os.replace(tmp_path, dest)  # atomic rename: dest is either absent or complete
    finally:
        # Only still present if something failed before the rename.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)


if not os.path.exists(SAM_CHECKPOINT_PATH):
    print("Downloading HIGH-QUALITY Segment Anything Model...")
    _download_file(SAM_DOWNLOAD_URL, SAM_CHECKPOINT_PATH)

# Deferred import, kept after the download step.
from segment_anything import sam_model_registry, SamPredictor  # noqa: E402

try:
    print("β³ Loading SAM model...")
    sam = sam_model_registry[SAM_MODEL_TYPE](checkpoint=SAM_CHECKPOINT_PATH).to(device=DEVICE)
    sam_predictor = SamPredictor(sam)  # module-level predictor used by generate_precise_mask
    print("β High-Quality SAM model loaded.")
except Exception as e:
    # This runs at import time, not inside a Gradio event handler, so raise a
    # plain exception and keep the original cause attached for the traceback.
    raise RuntimeError(f"Fatal: Could not load SAM model. Error: {e}") from e
# ==================================================================================
# Step 2: Core Functions
# ==================================================================================
def generate_precise_mask(image: Image.Image, progress: gr.Progress, *, blur_radius: float = 1):
    """Segment the garment region of a person image with the module-level SAM predictor.

    The prompt is fixed and relative to the image size:
    - two foreground points at 40 % and 60 % of the width, 45 % of the height;
    - one background point at the horizontal centre, 25 % of the height.
    These presumably target the torso and exclude the head of a centred,
    upright subject; this is an assumption about the input framing, not
    something checked here.

    Args:
        image: Person image. It is converted to 3-channel RGB before being
            handed to SAM, so grayscale / RGBA / palette images are accepted.
        progress: Gradio progress tracker; reported once at 30 %.
        blur_radius: Gaussian blur radius used to feather the mask edge.
            The default of 1 matches the previous hard-coded value.

    Returns:
        An "L"-mode mask the same size as ``image``, softened by the blur.
    """
    progress(0.3, desc="π€ Generating high-quality mask...")
    # SamPredictor.set_image expects an HxWx3 uint8 RGB array (per the SAM API).
    # An "L" image would give a 2-D array and break the shape unpack below.
    image_np = np.array(image.convert("RGB"))
    sam_predictor.set_image(image_np)
    h, w = image_np.shape[:2]

    input_points = np.array([[w * 0.40, h * 0.45], [w * 0.60, h * 0.45], [w * 0.5, h * 0.25]])
    input_labels = np.array([1, 1, 0])  # SAM convention: 1 = foreground, 0 = background
    masks, _, _ = sam_predictor.predict(
        point_coords=input_points,
        point_labels=input_labels,
        multimask_output=False,  # a single best mask
    )
    # masks[0] is presumably a boolean HxW array; converting to "L" yields 0/255.
    mask = Image.fromarray(masks[0]).convert("L")
    return mask.filter(ImageFilter.GaussianBlur(blur_radius))
def create_perfect_result(fabric_orig, person_base, mask, scale_factor=1.0):
    """Render one try-on variant: tile the fabric, shade it, and paste it through ``mask``.

    Steps:
    1. Resize the fabric so one tile is ``scale_factor`` times a quarter of the
       person image's width. The aspect ratio is kept and both sides are at
       least 1 px.
    2. Tile it from the top-left corner across an RGB canvas the size of
       ``person_base``.
    3. Blend the canvas with ``ImageChops.soft_light`` against an
       auto-contrasted grayscale copy of the person, so the person's shading
       shows through the pattern.
    4. Paste the result onto a copy of ``person_base``, using ``mask`` as the
       paste mask.

    Args:
        fabric_orig: Fabric/pattern image, any PIL mode (converted to RGB first).
        person_base: Person image. It is not modified; a copy is returned.
        mask: Paste mask for ``person_base``. Assumed to be an "L" image of the
            same size (as produced by generate_precise_mask); TODO confirm for
            other callers.
        scale_factor: Multiplier on the base tile width (a quarter of the
            person width).

    Returns:
        A new PIL image with the size and mode of ``person_base``.
    """
    # Pillow's resize always falls back to NEAREST for "1"/"P" images, so
    # convert first to get real LANCZOS filtering on palette-based fabrics.
    fabric = fabric_orig if fabric_orig.mode == "RGB" else fabric_orig.convert("RGB")

    base_size = int(person_base.width / 4)
    tile_w = max(1, int(base_size * scale_factor))
    fabric_w, fabric_h = fabric.size
    # Keep the fabric's aspect ratio; a zero-width fabric falls back to 1 px.
    tile_h = max(1, int(fabric_h * (tile_w / fabric_w) if fabric_w > 0 else 0))
    tile = fabric.resize((tile_w, tile_h), Image.Resampling.LANCZOS)

    tiled = Image.new("RGB", person_base.size)
    for x in range(0, person_base.width, tile_w):
        for y in range(0, person_base.height, tile_h):
            tiled.paste(tile, (x, y))

    # Light map: the person's own luminance (2 % clipped at each end), so the
    # soft-light blend transfers highlights and shadows onto the fabric.
    light_map = ImageOps.autocontrast(ImageOps.grayscale(person_base).convert("RGB"), cutoff=2)
    shaded = ImageChops.soft_light(tiled, light_map)

    final = person_base.copy()
    final.paste(shaded, (0, 0), mask=mask)
    return final
def load_image_from_url(url):
    """Fetch an image over HTTP(S) and return it as an RGB PIL image, or None.

    Best-effort: returns None instead of raising when any of these happens:
    - ``url`` is empty, None or only whitespace;
    - the request fails (network error, timeout, non-2xx status);
    - the payload cannot be decoded as an image.
    The caller can then report the missing image itself.

    Args:
        url: Image URL. Surrounding whitespace is ignored.

    Returns:
        A PIL image in mode "RGB", or None on any of the failures above.
    """
    if not url or not str(url).strip():
        return None
    try:
        # The whole body is read via .content, so streaming is unnecessary.
        # The with-block closes the response and releases the connection.
        with requests.get(str(url).strip(), timeout=10) as response:
            response.raise_for_status()
            payload = response.content
        return Image.open(BytesIO(payload)).convert("RGB")
    except (requests.RequestException, OSError, ValueError, EOFError, Image.DecompressionBombError):
        # Network/HTTP failures, undecodable or truncated images
        # (UnidentifiedImageError is an OSError) and oversized images.
        # Unlike a bare except, KeyboardInterrupt/SystemExit still propagate.
        return None
def generate_automatic_tryon(p_img_upload, p_img_url, f_img_upload, f_img_url, progress=gr.Progress(track_tqdm=True)):
    """Gradio click handler: segment the person and render three fabric-scale variants.

    For each image, an upload takes precedence. The URL is fetched only when
    the corresponding upload is None.

    Args:
        p_img_upload: Uploaded person image (PIL) or None.
        p_img_url: Person image URL text (may be empty).
        f_img_upload: Uploaded fabric image (PIL) or None.
        f_img_url: Fabric image URL text (may be empty).
        progress: Gradio progress tracker. Gradio supplies a live tracker in
            place of this default when the function is used as a handler.

    Returns:
        A 3-tuple ``(results, mask, mask)``:
        - ``results``: a list of three composites at fabric scale factors
          0.75, 0.4 and 1.2;
        - the SAM mask, returned twice, once for each mask slot in the
          handler's three-element outputs list.

    Raises:
        gr.Error: if the person or fabric image is neither uploaded nor
            loadable from its URL.
    """
    progress(0.05, desc="Loading images...")
    person_img = p_img_upload if p_img_upload is not None else load_image_from_url(p_img_url)
    fabric_img = f_img_upload if f_img_upload is not None else load_image_from_url(f_img_url)
    if person_img is None:
        raise gr.Error("Missing an image: upload a person photo or enter a valid Person URL.")
    if fabric_img is None:
        raise gr.Error("Missing an image: upload a fabric pattern or enter a valid Fabric URL.")

    # Uploads are not guaranteed to be RGB (RGBA/L images are possible).
    # SAM and the soft-light compositing both need 3 channels.
    person_resized = person_img.convert("RGB").resize(TARGET_SIZE, Image.Resampling.LANCZOS)
    fabric_img = fabric_img.convert("RGB")

    mask = generate_precise_mask(person_resized, progress)

    progress(0.8, desc="π¨ Applying fabric and lighting...")
    results = [create_perfect_result(fabric_img, person_resized, mask, sf) for sf in (0.75, 0.4, 1.2)]
    progress(1.0, desc="β Done!")
    return results, mask, mask
# ==================================================================================
# Step 3: Gradio User Interface
# ==================================================================================
# Gradio places components in the order they are created inside each
# Row/Column context, so the creation order below is the on-screen order.
with gr.Blocks(theme=gr.themes.Soft(), title="Virtual Try-On: Final Quality Edition") as demo:
    gr.Markdown("# π Virtual Try-On: The Final Quality Edition")

    with gr.Row():
        # Narrower left column: person and fabric inputs, then the trigger.
        with gr.Column(scale=2):
            p_upload = gr.Image(type="pil", label="Person in Suit")
            p_url = gr.Textbox(label="Person URL")
            f_upload = gr.Image(type="pil", label="Fabric Pattern")
            f_url = gr.Textbox(label="Fabric URL")
            btn = gr.Button("Generate Perfect Result", variant="primary")

        # Wider right column: the generated variants and the mask used for them.
        with gr.Column(scale=3):
            gallery = gr.Gallery(columns=3, object_fit="cover", height=512)
            mask_display = gr.Image(label="The Final, Precise Mask Used")

    # Inputs are passed positionally in this order to generate_automatic_tryon.
    # Its three return values map onto three outputs; mask_display is listed
    # twice, so both mask values land in the same component.
    btn.click(
        generate_automatic_tryon,
        inputs=[p_upload, p_url, f_upload, f_url],
        outputs=[gallery, mask_display, mask_display],
    )

demo.launch()