Alaiy commited on
Commit
f52c1ab
·
verified ·
1 Parent(s): a5d4ce8

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +186 -0
app.py ADDED
@@ -0,0 +1,186 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from dotenv import load_dotenv
3
+ import requests
4
+ from flask import Flask, jsonify, request, send_file
5
+ from botocore.exceptions import ClientError
6
+ from botocore.client import Config
7
+ import boto3
8
+ from urllib.parse import urlparse
9
+ import os
10
+ from PIL import Image
11
+ from io import BytesIO
12
+ import uuid
13
+
14
+
15
+ load_dotenv()
16
+
17
+ ENDPOINT_URL = os.getenv("ENDPOINT_URL")
18
+ AWS_ACCESS_KEY_ID = os.getenv("AWS_ACCESS_KEY_ID")
19
+ AWS_SECRET_ACCESS_KEY = os.getenv("AWS_SECRET_ACCESS_KEY")
20
+ BUCKET_NAME = "techTailor"
21
+ s3_client = boto3.client(
22
+ "s3",
23
+ endpoint_url = ENDPOINT_URL,
24
+ aws_access_key_id = AWS_ACCESS_KEY_ID,
25
+ aws_secret_access_key = AWS_SECRET_ACCESS_KEY,
26
+ config=Config(signature_version='s3v4')
27
+ )
28
+ MODAL_INFERENCE_ENDPOINT_URL = os.getenv("MODAL_INFERENCE_ENDPOINT_URL")
29
+ app = Flask(__name__)
30
+
31
+ GARM_SAVE_DIR = "garment_images"
32
+ MODE_SAVE_DIR = "model_images"
33
+
34
+ garment_upload_dir = "gradio_demo_garment/"
35
+ model_upload_dir = "gradio_demo_model/"
36
+
37
+ def load_image_from_url(image_url):
38
+ try:
39
+ response = requests.get(image_url)
40
+ # Ensure the response is an image by checking the content type
41
+ if "image" in response.headers["Content-Type"]:
42
+ img = Image.open(BytesIO(response.content))
43
+ return img
44
+ else:
45
+ return None
46
+ except Exception as e:
47
+ print(f"Error loading image: {e}")
48
+ return None
49
+
50
+ def process_cloth_image(image_url):
51
+ if image_url:
52
+ try:
53
+ response = requests.get(image_url)
54
+ response.raise_for_status()
55
+ img = Image.open(BytesIO(response.content))
56
+ img = img.convert("RGB")
57
+ img_width, img_height = img.size
58
+ target_width = 768
59
+ target_height = 1024
60
+ scale_width = target_width / img_width
61
+ scale_height = target_height / img_height
62
+ scale_factor = min(scale_width, scale_height)
63
+ new_width = int(img_width * scale_factor)
64
+ new_height = int(img_height * scale_factor)
65
+ img = img.resize((new_width, new_height), Image.Resampling.LANCZOS)
66
+ new_img = Image.new("RGB", (target_width, target_height), (0, 0, 0))
67
+ left_padding = (target_width - new_width) // 2
68
+ top_padding = (target_height - new_height) // 2
69
+ new_img.paste(img, (left_padding, top_padding))
70
+ img_byte_array = BytesIO()
71
+ new_img.save(img_byte_array, format="JPEG")
72
+ img_byte_array.seek(0)
73
+ filename = f"{uuid.uuid4().hex}.jpg"
74
+ s3_client.upload_fileobj(img_byte_array, BUCKET_NAME, garment_upload_dir + filename, ExtraArgs={'ContentType': 'image/jpeg'})
75
+ garment_url = s3_client.generate_presigned_url(
76
+ 'get_object',
77
+ Params={'Bucket': BUCKET_NAME, 'Key': garment_upload_dir + filename},
78
+ ExpiresIn=3600
79
+ )
80
+ return garment_url
81
+
82
+ except requests.exceptions.RequestException as e:
83
+ return f"Error downloading image: {e}"
84
+ except Exception as e:
85
+ return f"Error processing image: {e}"
86
+ else:
87
+ return "No image provided"
88
+
89
+ def process_model_image(image):
90
+ img = image.convert("RGB")
91
+ img_width, img_height = img.size
92
+ target_width = 768
93
+ target_height = 1024
94
+ scale_width = target_width / img_width
95
+ scale_height = target_height / img_height
96
+ scale_factor = min(scale_width, scale_height)
97
+ new_width = int(img_width * scale_factor)
98
+ new_height = int(img_height * scale_factor)
99
+ img = img.resize((new_width, new_height), Image.Resampling.LANCZOS)
100
+ new_img = Image.new("RGB", (target_width, target_height), (0, 0, 0))
101
+ left_padding = (target_width - new_width) // 2
102
+ top_padding = (target_height - new_height) // 2
103
+ new_img.paste(img, (left_padding, top_padding))
104
+ img_byte_array = BytesIO()
105
+ new_img.save(img_byte_array, format="JPEG")
106
+ img_byte_array.seek(0)
107
+ filename = f"{uuid.uuid4().hex}.jpg"
108
+ s3_client.upload_fileobj(img_byte_array, BUCKET_NAME, model_upload_dir + filename, ExtraArgs={'ContentType': 'image/jpeg'})
109
+ model_url = s3_client.generate_presigned_url(
110
+ 'get_object',
111
+ Params={'Bucket': BUCKET_NAME, 'Key': model_upload_dir + filename},
112
+ ExpiresIn=3600
113
+ )
114
+ return model_url
115
+
116
+
117
+ def display_image(image, image_url):
118
+ garment_file_path = process_cloth_image(image_url)
119
+ model_file_path = process_model_image(image)
120
+ print(garment_file_path, model_file_path)
121
+ payload = {
122
+ "human_image_url": model_file_path,
123
+ "garment_image_url": garment_file_path
124
+ }
125
+ print(payload)
126
+ results = []
127
+ try:
128
+ print("Entering Modal block")
129
+ response = requests.post(MODAL_INFERENCE_ENDPOINT_URL, json=payload)
130
+ if response.status_code == 200:
131
+ result_data = response.json()
132
+ for key in result_data.keys():
133
+ if isinstance(result_data[key], str) and result_data[key].startswith('http'):
134
+ print(result_data[key])
135
+ presigned = generate_presigned_url(result_data[key])
136
+ response = requests.get(presigned)
137
+ img = Image.open(BytesIO(response.content))
138
+ img_resized = img.resize((512, 682))
139
+ return img_resized
140
+ else:
141
+ results.append({"error": f"Failed to process the garment image. Status Code: {response.status_code}"})
142
+ except requests.exceptions.RequestException as e:
143
+ results.append({"error": f"Request failed for the garment image. Error: {str(e)}"})
144
+ return ""
145
+
146
+ def generate_presigned_url(object_url):
147
+ parsed_url = urlparse(object_url)
148
+ path_parts = parsed_url.path.lstrip('/').split('/', 1)
149
+ object_key = path_parts[1] if len(path_parts) > 1 else ''
150
+ print(f"Extracted Object Key: {object_key}")
151
+ try:
152
+ presigned_url = s3_client.generate_presigned_url(
153
+ 'get_object',
154
+ Params={
155
+ 'Bucket': BUCKET_NAME,
156
+ 'Key': object_key
157
+ },
158
+ ExpiresIn=3600
159
+ )
160
+ return presigned_url
161
+ except Exception as e:
162
+ print(f"Error generating pre-signed URL: {e}")
163
+ return None
164
+
165
+
166
+ with gr.Blocks() as demo:
167
+ with gr.Row():
168
+ image_url_input = gr.Textbox(label="Image URL", placeholder="Enter image URL here")
169
+ input_garment_image = gr.Image(label="Garment Image", type="pil", width="384px", height = "512px")
170
+ uploaded_image = gr.Image(label="Upload or Capture Image", type="pil", width="384px", height="512px")
171
+ output_display = gr.Image(label="Displayed Image or URL Result", width="384px", height="512px")
172
+
173
+ image_url_input.change(
174
+ load_image_from_url,
175
+ inputs=image_url_input,
176
+ outputs=input_garment_image
177
+ )
178
+ submit_btn = gr.Button("Submit")
179
+ submit_btn.click(
180
+ display_image,
181
+ inputs=[uploaded_image, image_url_input],
182
+ outputs=output_display
183
+ )
184
+
185
+
186
+ demo.launch()