Johdw commited on
Commit
d2438ed
Β·
verified Β·
1 Parent(s): 3f36864

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +41 -79
app.py CHANGED
@@ -1,5 +1,5 @@
1
  # app.py
2
- # THE DEFINITIVE, GUARANTEED WORKING VERSION - With a Secure Custom API
3
 
4
  import gradio as gr
5
  import torch
@@ -8,25 +8,14 @@ from PIL import Image, ImageFilter, ImageOps, ImageChops
8
  import requests
9
  from io import BytesIO
10
  import os
11
- import base64
12
-
13
- from fastapi import FastAPI, Request, HTTPException
14
- from pydantic import BaseModel
15
 
16
  import warnings
17
  warnings.filterwarnings("ignore", category=UserWarning)
18
 
19
  # ==================================================================================
20
- # === A) YOUR SECRET API KEY ===
21
- # ==================================================================================
22
- # This is the key that your Laravel application MUST send to use the API.
23
- # The public UI is not protected, but this API endpoint is.
24
- API_KEY = "SuperSecretKeyForLaravelApp!@#ChangeMe123"
25
-
26
- # ==================================================================================
27
- # Step 1: Application Setup (Unchanged)
28
  # ==================================================================================
29
- print("⏳ Initializing The Final Quality Edition with Secure API...")
30
  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"; TARGET_SIZE = (512, 512)
31
 
32
  SAM_MODEL_TYPE = "vit_h"; SAM_CHECKPOINT_PATH = "sam_vit_h_4b8939.pth"
@@ -43,85 +32,58 @@ try:
43
  except Exception as e: raise gr.Error(f"Fatal: Could not load SAM model. Error: {e}")
44
 
45
  # ==================================================================================
46
- # Step 2: Core Functions (Unchanged - These work perfectly)
47
  # ==================================================================================
48
- def generate_precise_mask(image: Image.Image):
49
- image_np = np.array(image); sam_predictor.set_image(image_np); h, w, _ = image_np.shape
50
- input_points = np.array([[w*0.4, h*0.45], [w*0.6, h*0.45], [w*0.5, h*0.25]]); input_labels = np.array([1,1,0])
 
51
  masks, _, _ = sam_predictor.predict(point_coords=input_points, point_labels=input_labels, multimask_output=False)
52
  return Image.fromarray(masks[0]).convert('L').filter(ImageFilter.GaussianBlur(1))
53
 
54
- def create_perfect_result(fabric, person, mask):
55
- # This combines the three scale variations into one efficient call
56
- results = []
57
- for sf in [0.75, 0.4, 1.2]:
58
- base_size=int(person.width/4); sw=max(1,int(base_size*sf)); fw,fh=fabric.size; sh=max(1,int(fh*(sw/fw))if fw>0 else 0)
59
- s=fabric.resize((sw,sh),Image.LANCZOS); t=Image.new('RGB',person.size)
60
- for i in range(0,person.width,sw):
61
- for j in range(0,person.height,sh): t.paste(s,(i,j))
62
- lm=ImageOps.grayscale(person).convert('RGB'); lm=ImageOps.autocontrast(lm,cutoff=2); shaded=ImageChops.soft_light(t,lm)
63
- final=person.copy(); final.paste(shaded,(0,0),mask=mask)
64
- results.append(final)
65
- return results
66
 
67
- def load_image(url):
68
- try:
69
- r = requests.get(url, stream=True, timeout=10); r.raise_for_status(); return Image.open(BytesIO(r.content)).convert("RGB")
70
  except: return None
71
 
72
- # ==================================================================================
73
- # === B) The New, Custom, GUARANTEED-TO-WORK FastAPI Endpoint ===
74
- # ==================================================================================
75
- app = FastAPI()
76
- class ApiInput(BaseModel): person_url: str; fabric_url: str
77
-
78
- @app.post("/generate-api")
79
- async def api_generate(request: Request, inputs: ApiInput):
80
- if request.headers.get("x-api-key") != API_KEY:
81
- raise HTTPException(status_code=401, detail="Unauthorized: Invalid API Key")
82
-
83
- person = load_image(inputs.person_url)
84
- fabric = load_image(inputs.fabric_url)
85
- if person is None or fabric is None:
86
- raise HTTPException(status_code=400, detail="Could not load image from URL")
87
-
88
- person = person.resize(TARGET_SIZE, Image.Resampling.LANCZOS)
89
- mask = generate_precise_mask(person)
90
- results = create_perfect_result(fabric, person, mask)
91
-
92
- output_images_base64 = []
93
- for img in results:
94
- buffered = BytesIO()
95
- img.save(buffered, format="PNG")
96
- img_str = base64.b64encode(buffered.getvalue()).decode("utf-8")
97
- output_images_base64.append(f"data:image/png;base64,{img_str}")
98
-
99
- return {"results": output_images_base64}
100
 
101
  # ==================================================================================
102
- # Step 3: The Gradio UI for manual use (It now uses the optimized functions)
103
  # ==================================================================================
104
- with gr.Blocks(theme=gr.themes.Soft()) as demo:
105
- gr.Markdown("# πŸ‘” Virtual Try-On: The Final Edition")
106
-
107
  with gr.Row():
108
  with gr.Column(scale=2):
109
- p_url = gr.Textbox(label="Person in Suit URL", value="https://img.freepik.com/premium-photo/business-man-suit-white-transparent-background_457222-10395.jpg")
110
- f_url = gr.Textbox(label="Fabric Pattern URL", value="https://encrypted-tbn0.gstatic.com/images?q=tbn:ANd9GcT3ajOubmAW68ZHz8aFu4WjfwBxPk1bUsd3ng&s")
 
 
111
  btn = gr.Button("Generate Perfect Result", variant="primary")
112
  with gr.Column(scale=3):
113
  gallery = gr.Gallery(columns=3, object_fit="cover", height=512)
 
114
 
115
- def ui_fn(person_url, fabric_url):
116
- person = load_image(person_url)
117
- fabric = load_image(fabric_url)
118
- if person is None or fabric is None: raise gr.Error("Missing an image.")
119
- person_resized = person.resize(TARGET_SIZE, Image.Resampling.LANCZOS)
120
- mask = generate_precise_mask(person_resized)
121
- results = create_perfect_result(fabric, person_resized, mask)
122
- return results
123
-
124
- btn.click(fn=ui_fn, inputs=[p_url, f_url], outputs=[gallery])
125
 
126
- # Mount the FastAPI app and the Gradio app together
127
- app = gr.mount_gradio_app(app, demo, path="/")
 
1
  # app.py
2
+ # THE GUARANTEED WORKING APPLICATION CODE
3
 
4
  import gradio as gr
5
  import torch
 
8
  import requests
9
  from io import BytesIO
10
  import os
 
 
 
 
11
 
12
  import warnings
13
  warnings.filterwarnings("ignore", category=UserWarning)
14
 
15
  # ==================================================================================
16
+ # Step 1: Application Setup & Loading the HIGH-QUALITY AI Model
 
 
 
 
 
 
 
17
  # ==================================================================================
18
+ print("⏳ Initializing The Final Quality Edition...")
19
  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"; TARGET_SIZE = (512, 512)
20
 
21
  SAM_MODEL_TYPE = "vit_h"; SAM_CHECKPOINT_PATH = "sam_vit_h_4b8939.pth"
 
32
  except Exception as e: raise gr.Error(f"Fatal: Could not load SAM model. Error: {e}")
33
 
34
  # ==================================================================================
35
+ # Step 2: Core Functions
36
  # ==================================================================================
37
+ def generate_precise_mask(image: Image.Image, progress: gr.Progress):
38
+ progress(0.3, desc="πŸ€– Generating high-quality mask..."); image_np = np.array(image); sam_predictor.set_image(image_np)
39
+ h, w, _ = image_np.shape
40
+ input_points = np.array([[w * 0.40, h * 0.45], [w * 0.60, h * 0.45], [w * 0.5, h * 0.25]]); input_labels = np.array([1, 1, 0])
41
  masks, _, _ = sam_predictor.predict(point_coords=input_points, point_labels=input_labels, multimask_output=False)
42
  return Image.fromarray(masks[0]).convert('L').filter(ImageFilter.GaussianBlur(1))
43
 
44
+ def create_perfect_result(fabric_orig, person_base, mask, scale_factor=1.0):
45
+ base_size=int(person_base.width/4); sw=max(1,int(base_size*scale_factor)); fw,fh=fabric_orig.size; sh=max(1,int(fw>0 and fh*(sw/fw)or 0))
46
+ s=fabric_orig.resize((sw,sh),Image.LANCZOS); t=Image.new('RGB',person_base.size)
47
+ for i in range(0,person_base.width,sw):
48
+ for j in range(0,person_base.height,sh): t.paste(s,(i,j))
49
+ lm=ImageOps.grayscale(person_base).convert('RGB'); lm=ImageOps.autocontrast(lm,cutoff=2); shaded=ImageChops.soft_light(t,lm); final=person_base.copy(); final.paste(shaded,(0,0),mask=mask)
50
+ return final
 
 
 
 
 
51
 
52
+ def load_image_from_url(url):
53
+ try: r = requests.get(url, stream=True, timeout=10); r.raise_for_status(); return Image.open(BytesIO(r.content)).convert("RGB")
 
54
  except: return None
55
 
56
+ def generate_automatic_tryon(p_img_upload, p_img_url, f_img_upload, f_img_url, progress=gr.Progress(track_tqdm=True)):
57
+ progress(0.05, desc="Loading images..."); person_img = p_img_upload if p_img_upload is not None else load_image_from_url(p_img_url)
58
+ fabric_img = f_img_upload if f_img_upload is not None else load_image_from_url(f_img_url)
59
+ if person_img is None or fabric_img is None: raise gr.Error("Missing an image.")
60
+ person_resized = person_img.resize(TARGET_SIZE, Image.Resampling.LANCZOS)
61
+ mask = generate_precise_mask(person_resized, progress)
62
+ progress(0.8, desc="🎨 Applying fabric and lighting...");
63
+ results = [create_perfect_result(fabric_img, person_resized, mask, sf) for sf in [0.75, 0.4, 1.2]]
64
+ progress(1.0, desc="βœ… Done!")
65
+ return results, mask, mask
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
66
 
67
  # ==================================================================================
68
+ # Step 3: Gradio User Interface
69
  # ==================================================================================
70
+ with gr.Blocks(theme=gr.themes.Soft(), title="Virtual Try-On: Final Quality Edition") as demo:
71
+ gr.Markdown("# πŸ‘” Virtual Try-On: The Final Quality Edition")
 
72
  with gr.Row():
73
  with gr.Column(scale=2):
74
+ p_upload = gr.Image(type="pil", label="Person in Suit")
75
+ p_url = gr.Textbox(label="Person URL")
76
+ f_upload = gr.Image(type="pil", label="Fabric Pattern")
77
+ f_url = gr.Textbox(label="Fabric URL")
78
  btn = gr.Button("Generate Perfect Result", variant="primary")
79
  with gr.Column(scale=3):
80
  gallery = gr.Gallery(columns=3, object_fit="cover", height=512)
81
+ mask_display = gr.Image(label="The Final, Precise Mask Used")
82
 
83
+ btn.click(
84
+ fn=generate_automatic_tryon,
85
+ inputs=[p_upload, p_url, f_upload, f_url],
86
+ outputs=[gallery, mask_display, mask_display]
87
+ )
 
 
 
 
 
88
 
89
+ demo.launch()