dschandra commited on
Commit
756a125
·
verified ·
1 Parent(s): 2a43965

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +120 -38
app.py CHANGED
@@ -1,44 +1,126 @@
 
 
 
 
 
 
 
 
1
  import gradio as gr
2
- from transformers import CLIPProcessor, CLIPModel
3
- from PIL import Image
4
- import torch
5
-
6
- # Load the CLIP model and processor from Hugging Face
7
- model_name = "openai/clip-vit-base-patch32"
8
- model = CLIPModel.from_pretrained(model_name)
9
- processor = CLIPProcessor.from_pretrained(model_name)
10
-
11
- def generate_outfit(base_image, garment_image):
12
- # Open the images using PIL
13
- base_image = Image.open(base_image).convert("RGB")
14
- garment_image = Image.open(garment_image).convert("RGB")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
15
 
16
- # Process the images with the CLIP model
17
- inputs = processor(
18
- text=["a person wearing a garment"],
19
- images=[base_image, garment_image],
20
- return_tensors="pt",
21
- padding=True # Ensure padding is applied for batched inputs
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
22
  )
23
-
24
- # Perform inference
25
- with torch.no_grad():
26
- outputs = model(**inputs)
27
-
28
- # Get the similarity score between the base image and the garment image
29
- similarity_score = outputs.logits_per_image.item()
30
-
31
- # Return the similarity score
32
- return f"Similarity Score: {similarity_score:.4f}"
33
-
34
- # Gradio Interface
35
- demo = gr.Interface(
36
- fn=generate_outfit,
37
- inputs=[gr.Image(type="filepath", label="Base Image"), gr.Image(type="filepath", label="Garment Image")],
38
- outputs="text",
39
- title="Outfit Generator",
40
- description="Upload a base image and a garment image to generate outfit suggestions using CLIP."
41
- )
42
 
43
  if __name__ == "__main__":
 
44
  demo.launch()
 
1
+ import os
2
+ import cv2
3
+ import numpy as np
4
+ import json
5
+ import random
6
+ from PIL import Image, ImageDraw, ImageFont
7
+ import requests
8
+ import base64
9
  import gradio as gr
10
+
11
# Define the path to the default model shown in the UI on load.
# NOTE(review): paths are resolved relative to this file's directory —
# the models/ tree must ship alongside app.py for these to exist.
default_model = os.path.join(os.path.dirname(__file__), "models/eva/Eva_0.png")

# Map of AI models with their corresponding file paths.
# Keys match the "model_name" strings the try-on server expects
# (see get_tryon_result, which rebuilds the key from the file stem).
MODEL_MAP = {
    "AI Model Rouyan_0": 'models/rouyan_new/Rouyan_0.png',
    "AI Model Rouyan_1": 'models/rouyan_new/Rouyan_1.png',
    "AI Model Rouyan_2": 'models/rouyan_new/Rouyan_2.png',
    "AI Model Eva_0": 'models/eva/Eva_0.png',
    "AI Model Eva_1": 'models/eva/Eva_1.png',
    "AI Model Simon_0": 'models/simon_online/Simon_0.png',
    "AI Model Simon_1": 'models/simon_online/Simon_1.png',
    "AI Model Xuanxuan_0": 'models/xiaoxuan_online/Xuanxuan_0.png',
    "AI Model Xuanxuan_1": 'models/xiaoxuan_online/Xuanxuan_1.png',
    "AI Model Xuanxuan_2": 'models/xiaoxuan_online/Xuanxuan_2.png',
    "AI Model Yaqi_0": 'models/yaqi/Yaqi_0.png',
    "AI Model Yaqi_1": 'models/yaqi/Yaqi_1.png',
    "AI Model Yaqi_2": 'models/yaqi/Yaqi_2.png',
    "AI Model Yaqi_3": 'models/yaqi/Yaqi_3.png',
    "AI Model Yifeng_0": 'models/yifeng_online/Yifeng_0.png',
    "AI Model Yifeng_1": 'models/yifeng_online/Yifeng_1.png',
    "AI Model Yifeng_2": 'models/yifeng_online/Yifeng_2.png',
    "AI Model Yifeng_3": 'models/yifeng_online/Yifeng_3.png',
}
35
+
36
def add_watermark(image):
    """
    Draw the 'Powered by OutfitAnyone' watermark onto an image, in place.

    Parameters
    ----------
    image : numpy.ndarray
        Image as decoded by cv2 (BGR, BGRA, or grayscale). The array is
        modified in place by ``cv2.putText``.

    Returns
    -------
    numpy.ndarray
        The same array, with the watermark rendered near the bottom edge.
    """
    # BUG FIX: the caller decodes the server result with cv2.IMREAD_UNCHANGED,
    # which can yield a 2-D (grayscale) or 4-channel array. The original
    # `height, width, _ = image.shape` unpack crashes on a 2-D array, so take
    # only the first two dimensions instead.
    height, width = image.shape[:2]
    cv2.putText(image, 'Powered by OutfitAnyone', (int(0.3 * width), height - 20),
                cv2.FONT_HERSHEY_PLAIN, 2, (128, 128, 128), 2, cv2.LINE_AA)
    return image
44
+
45
def get_tryon_result(model_name, top_garment, bottom_garment=None):
    """
    Request a virtual try-on result from the inference server.

    Parameters
    ----------
    model_name : str
        Filepath of the selected model image; its file stem is mapped to the
        server-side model key (e.g. ".../Eva_0.png" -> "AI Model Eva_0").
    top_garment : numpy.ndarray
        Top-garment image as supplied by ``gr.Image(type="numpy")``.
    bottom_garment : numpy.ndarray or None, optional
        Optional bottom-garment image; sent as an empty string when absent.

    Returns
    -------
    numpy.ndarray or None
        Watermarked result image on success, or None when the server responds
        with a non-200 status.
    """
    # Derive the server-side model key from the image file name.
    # BUG FIX: use os.path instead of split("/") so Windows-style paths
    # (backslash separators) are handled correctly as well.
    model_stem = os.path.splitext(os.path.basename(model_name))[0]
    model_key = "AI Model " + model_stem
    print(f"Selected model: {model_key}")

    # Encode the garments as base64 JPEG payloads.
    encoded_top = base64.b64encode(cv2.imencode('.jpg', top_garment)[1].tobytes()).decode('utf-8')
    # BUG FIX: the original `if bottom_garment else ''` truth-tests a numpy
    # array, which raises ValueError ("truth value of an array ... is
    # ambiguous") whenever a bottom garment is actually uploaded.
    # Compare against None explicitly instead.
    if bottom_garment is not None:
        encoded_bottom = base64.b64encode(cv2.imencode('.jpg', bottom_garment)[1].tobytes()).decode('utf-8')
    else:
        encoded_bottom = ''

    # Server request setup; default to localhost when OA_IP_ADDRESS is unset.
    server_url = os.environ.get('OA_IP_ADDRESS', 'http://localhost:5000')
    headers = {'Content-Type': 'application/json'}
    payload = {
        "garment1": encoded_top,
        "garment2": encoded_bottom,
        "model_name": model_key,
        # Random seed so repeated requests produce varied generations.
        "seed": random.randint(0, 99999999)
    }

    # Send the request. ROBUSTNESS: a timeout prevents the UI from hanging
    # forever if the inference server is unreachable.
    response = requests.post(server_url, headers=headers, data=json.dumps(payload), timeout=120)

    if response.status_code == 200:
        result = response.json()
        # The server returns base64-encoded images; decode the first one.
        result_img = cv2.imdecode(np.frombuffer(base64.b64decode(result['images'][0]), np.uint8), cv2.IMREAD_UNCHANGED)
        final_img = add_watermark(result_img)
        return final_img
    else:
        print(f"Error: Server responded with status code {response.status_code}")
        return None
78
+
79
# Set up the Gradio interface.
# Layout: three columns — model picker, garment uploads, result display —
# followed by a row of example images.
with gr.Blocks(css=".output-image, .input-image, .image-preview {height: 400px !important}") as demo:
    gr.HTML("""
    <div style="text-align: center;">
        <h1>Outfit Anyone: Virtual Try-On</h1>
        <h4>v1.0</h4>
        <p>Upload your garments and choose a model to see the virtual try-on.</p>
    </div>
    """)

    with gr.Row():
        with gr.Column():
            # type="filepath" so get_tryon_result receives a path string it can
            # parse the model key from.
            # NOTE(review): sources='clipboard' is a bare string; recent Gradio
            # versions expect a list (e.g. ["clipboard"]) — verify against the
            # pinned Gradio version.
            model_selector = gr.Image(sources='clipboard', type="filepath", label="Model", value=default_model)
            example_models = gr.Examples(inputs=model_selector,
                                         examples=[MODEL_MAP['AI Model Rouyan_0'], MODEL_MAP['AI Model Eva_0']],
                                         examples_per_page=4)
        with gr.Column():
            gr.HTML("<h3>Select Garments for Virtual Try-On</h3>")
            # type="numpy" so the click handler gets arrays it can JPEG-encode.
            top_garment_input = gr.Image(sources='upload', type="numpy", label="Top Garment")
            bottom_garment_input = gr.Image(sources='upload', type="numpy", label="Bottom Garment (Optional)")
            # NOTE(review): example paths assume garments/ exists next to
            # app.py — confirm the assets are deployed with the app.
            example_top_garments = gr.Examples(inputs=top_garment_input,
                                               examples=[os.path.join(os.path.dirname(__file__), "garments/top1.jpg")],
                                               examples_per_page=5)
            example_bottom_garments = gr.Examples(inputs=bottom_garment_input,
                                                  examples=[os.path.join(os.path.dirname(__file__), "garments/bottom1.jpg")],
                                                  examples_per_page=5)
            generate_button = gr.Button(value="Generate Outfit")

        with gr.Column():
            # Populated by the click handler with the watermarked result.
            result_display = gr.Image()

    # Wire the button to the try-on request; bottom garment may be None.
    generate_button.click(fn=get_tryon_result,
                          inputs=[model_selector, top_garment_input, bottom_garment_input],
                          outputs=[result_display])

    # Static example gallery (display only; not fed into the handler).
    gr.Markdown("## Example Outputs")
    with gr.Row():
        ref_image = gr.Image(label="Model Example", value="examples/model_example.jpg")
        garment_example = gr.Image(label="Garment Example", value="examples/garment_example.jpg")
        result_example = gr.Image(label="Result Example", value="examples/result_example.jpg")
    gr.Examples(
        examples=[["examples/model_example.jpg", "examples/garment_example.jpg", "examples/result_example.jpg"]],
        inputs=[ref_image, garment_example, result_example],
    )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
123
 
124
if __name__ == "__main__":
    # Queue incoming requests (at most 10 waiting) so concurrent users do not
    # overload the try-on backend, then start the Gradio server.
    demo.queue(max_size=10)
    demo.launch()