File size: 6,709 Bytes
f52c1ab
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c59f06a
f52c1ab
 
c59f06a
3e4548a
 
1bc2682
f52c1ab
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5f0fe6c
f52c1ab
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5f0fe6c
f52c1ab
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c59f06a
 
 
 
 
f52c1ab
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ebcb30e
 
 
 
 
 
c0c7e30
 
 
 
 
ebcb30e
 
 
 
 
 
f52c1ab
8a5aecb
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
import gradio as gr
from dotenv import load_dotenv
import requests
from flask import Flask, jsonify, request, send_file
from botocore.exceptions import ClientError
from botocore.client import Config
import boto3
from urllib.parse import urlparse
import os
from PIL import Image
from io import BytesIO
import uuid


# Load variables from a local .env file into the process environment. This must
# run before the os.getenv() calls below so those values are visible to them.
load_dotenv()

# AWS credentials are read from the environment (possibly populated by .env).
AWS_ACCESS_KEY_ID = os.getenv("AWS_ACCESS_KEY_ID")
AWS_SECRET_ACCESS_KEY = os.getenv("AWS_SECRET_ACCESS_KEY")
# S3 bucket that receives the preprocessed garment/model images.
BUCKET_NAME = "tech-tailor"
# Shared S3 client used for uploads and pre-signed URL generation throughout
# this module; SigV4 signing is requested explicitly via Config.
s3_client = boto3.client(
    "s3",
    region_name='ap-south-1',
    aws_access_key_id=AWS_ACCESS_KEY_ID, 
    aws_secret_access_key=AWS_SECRET_ACCESS_KEY,
    config=Config(signature_version='s3v4')
)
# Try-on inference endpoint that display_image() POSTs the image URLs to.
MODAL_INFERENCE_ENDPOINT_URL = os.getenv("MODAL_INFERENCE_ENDPOINT_URL")
# NOTE(review): no Flask routes are registered in the visible code and the UI
# is served by Gradio — confirm this app object is still needed.
app = Flask(__name__)

# NOTE(review): these two names are not referenced in the visible code (and
# "MODE_" looks like a typo for "MODEL_") — confirm before relying on them.
GARM_SAVE_DIR = "garment_images"
MODE_SAVE_DIR = "model_images"

# S3 key prefixes for uploaded images. They are concatenated directly with a
# generated filename, so the trailing "/" acts as the path separator.
garment_upload_dir = "gradio_demo_garment/"
model_upload_dir = "gradio_demo_model/"

def load_image_from_url(image_url, *, timeout=10):
    """Fetch the image at ``image_url`` for the garment preview.

    This is wired to the URL textbox's ``change`` event, so it runs on every
    edit of the textbox, including when the box is cleared.

    Args:
        image_url: URL entered by the user. Empty or whitespace-only values
            are ignored without making a network request.
        timeout: Seconds to wait for the HTTP response before giving up.

    Returns:
        A ``PIL.Image.Image`` when the response declares an image content
        type, otherwise ``None`` (also on any download or decoding error).
    """
    # Blank input (e.g. a cleared textbox) can never be fetched; skip the request.
    if not image_url or not str(image_url).strip():
        return None
    try:
        # Without a timeout a slow or unresponsive host would block the handler.
        response = requests.get(image_url, timeout=timeout)
        # Servers are not required to send Content-Type; treat absence as "not an image".
        content_type = response.headers.get("Content-Type", "")
        if "image" not in content_type:
            return None
        return Image.open(BytesIO(response.content))
    except Exception as e:
        # Best-effort preview: report the problem and show nothing.
        print(f"Error loading image: {e}")
        return None
    
def process_cloth_image(image_url, *, timeout=10):
    """Download a garment image, letterbox it to 768x1024 and upload it to S3.

    The image is converted to RGB, scaled (preserving aspect ratio) to fit a
    768x1024 canvas and centred on a black background. It is then JPEG-encoded
    and stored under ``garment_upload_dir`` in ``BUCKET_NAME`` with a random
    filename.

    Args:
        image_url: URL of the garment image. Falsy values are rejected.
        timeout: Seconds to wait for the download before giving up.

    Returns:
        On success, a pre-signed GET URL (valid for 3600 seconds) for the
        uploaded JPEG. On failure, a human-readable error string instead of an
        exception: ``"No image provided"``, ``"Error downloading image: ..."``
        or ``"Error processing image: ..."``.
    """
    if not image_url:
        return "No image provided"
    try:
        response = requests.get(image_url, timeout=timeout)
        response.raise_for_status()
        img = Image.open(BytesIO(response.content)).convert("RGB")

        target_width, target_height = 768, 1024
        img_width, img_height = img.size
        scale_factor = min(target_width / img_width, target_height / img_height)
        # max(1, ...): on extreme aspect ratios int() could truncate a side to
        # 0 px, which resize() rejects.
        new_width = max(1, int(img_width * scale_factor))
        new_height = max(1, int(img_height * scale_factor))
        img = img.resize((new_width, new_height), Image.Resampling.LANCZOS)

        # Centre the scaled image on a black canvas of the target size.
        canvas = Image.new("RGB", (target_width, target_height), (0, 0, 0))
        canvas.paste(
            img,
            ((target_width - new_width) // 2, (target_height - new_height) // 2),
        )

        buffer = BytesIO()
        canvas.save(buffer, format="JPEG")
        buffer.seek(0)

        # Build the key once so the upload and the signed URL cannot diverge.
        key = f"{garment_upload_dir}{uuid.uuid4().hex}.jpg"
        s3_client.put_object(Body=buffer, Bucket=BUCKET_NAME, Key=key, ContentType='image/jpeg')
        return s3_client.generate_presigned_url(
            'get_object',
            Params={'Bucket': BUCKET_NAME, 'Key': key},
            ExpiresIn=3600,
        )
    except requests.exceptions.RequestException as e:
        return f"Error downloading image: {e}"
    except Exception as e:
        return f"Error processing image: {e}"

def process_model_image(image):
    """Letterbox the model (person) photo to 768x1024 and upload it to S3.

    The image is converted to RGB, scaled (preserving aspect ratio) to fit a
    768x1024 canvas and centred on a black background. It is then JPEG-encoded
    and stored under ``model_upload_dir`` in ``BUCKET_NAME`` with a random
    filename.

    Args:
        image: The photo from the Gradio image input, a ``PIL.Image.Image``.

    Returns:
        A pre-signed GET URL (valid for 3600 seconds) for the uploaded JPEG.

    Raises:
        ValueError: if ``image`` is None (nothing was uploaded or captured).
    """
    if image is None:
        # Fail with a clear message rather than AttributeError on None.convert().
        raise ValueError("No model image provided")

    img = image.convert("RGB")
    target_width, target_height = 768, 1024
    img_width, img_height = img.size
    scale_factor = min(target_width / img_width, target_height / img_height)
    # max(1, ...): on extreme aspect ratios int() could truncate a side to
    # 0 px, which resize() rejects.
    new_width = max(1, int(img_width * scale_factor))
    new_height = max(1, int(img_height * scale_factor))
    img = img.resize((new_width, new_height), Image.Resampling.LANCZOS)

    # Centre the scaled image on a black canvas of the target size.
    canvas = Image.new("RGB", (target_width, target_height), (0, 0, 0))
    canvas.paste(
        img,
        ((target_width - new_width) // 2, (target_height - new_height) // 2),
    )

    buffer = BytesIO()
    canvas.save(buffer, format="JPEG")
    buffer.seek(0)

    # Build the key once so the upload and the signed URL cannot diverge.
    key = f"{model_upload_dir}{uuid.uuid4().hex}.jpg"
    s3_client.put_object(Body=buffer, Bucket=BUCKET_NAME, Key=key, ContentType='image/jpeg')
    return s3_client.generate_presigned_url(
        'get_object',
        Params={'Bucket': BUCKET_NAME, 'Key': key},
        ExpiresIn=3600,
    )

    
def display_image(image, image_url):
    """Run virtual try-on for the uploaded model photo and the garment URL.

    Both inputs are preprocessed and uploaded to S3. The resulting pre-signed
    URLs are POSTed to ``MODAL_INFERENCE_ENDPOINT_URL``. The image at the
    ``"url"`` field of the JSON reply is then downloaded and resized to
    512x682.

    Args:
        image: Model photo from the Gradio image input (PIL image or None).
        image_url: Garment image URL from the textbox.

    Returns:
        The resized try-on result as a PIL image.

    Raises:
        gr.Error: on missing inputs, garment preprocessing failure, or any
            failure talking to the inference endpoint. Gradio then shows the
            message in the UI instead of silently displaying nothing.
    """
    if image is None:
        raise gr.Error("Please upload or capture a model image.")
    if not image_url:
        raise gr.Error("Please enter a garment image URL.")

    garment_url = process_cloth_image(image_url)
    # process_cloth_image reports failures as plain strings instead of raising;
    # never forward such a string to the endpoint as if it were a URL.
    if not garment_url.startswith(("http://", "https://")):
        raise gr.Error(garment_url)
    model_url = process_model_image(image)

    payload = {
        "human_image_url": model_url,
        "garment_image_url": garment_url,
    }

    # NOTE(review): no timeout here because inference may legitimately take a
    # long time; consider a generous upper bound.
    try:
        response = requests.post(MODAL_INFERENCE_ENDPOINT_URL, json=payload)
    except requests.exceptions.RequestException as e:
        raise gr.Error(f"Request failed for the garment image. Error: {e}") from e
    if response.status_code != 200:
        raise gr.Error(f"Failed to process the garment image. Status Code: {response.status_code}")

    try:
        result_url = response.json()["url"]
        result_response = requests.get(result_url, timeout=60)
        result_response.raise_for_status()
        # resize() forces PIL to decode, so decode errors surface inside this try.
        return Image.open(BytesIO(result_response.content)).resize((512, 682))
    except requests.exceptions.RequestException as e:
        # Listed before OSError: requests' exceptions derive from IOError/OSError.
        raise gr.Error(f"Request failed for the garment image. Error: {e}") from e
    except (KeyError, TypeError, ValueError, OSError) as e:
        # Malformed JSON / missing "url" key / undecodable result image.
        raise gr.Error(f"Unexpected response from the inference endpoint: {e}") from e

def generate_presigned_url(object_url):
    """Return a 1-hour pre-signed GET URL for an object in ``BUCKET_NAME``.

    ``object_url`` may be in either of two forms:

    - virtual-hosted style: ``https://<bucket>.s3.<region>.amazonaws.com/<key>``
    - path style: ``https://s3.<region>.amazonaws.com/<bucket>/<key>``

    The object key is extracted and percent-decoded. The bucket used for
    signing is always ``BUCKET_NAME``.

    Returns:
        The pre-signed URL, or None when no key could be extracted or signing
        failed.
    """
    from urllib.parse import unquote

    parsed_url = urlparse(object_url)
    path = parsed_url.path.lstrip('/')
    host = parsed_url.hostname or ""
    if host.startswith(f"{BUCKET_NAME}."):
        # Virtual-hosted style: the bucket lives in the hostname, so the whole
        # path is the key.
        object_key = path
    else:
        # Path style: the first path segment is the bucket name.
        object_key = path.split('/', 1)[1] if '/' in path else ''
    # URL paths are percent-encoded (e.g. a space becomes %20); S3 keys are not.
    # NOTE(review): if URLs come from a source that encodes spaces as '+',
    # unquote_plus would be needed instead — confirm.
    object_key = unquote(object_key)
    print(f"Extracted Object Key: {object_key}")
    if not object_key:
        return None
    try:
        return s3_client.generate_presigned_url(
            'get_object',
            Params={
                'Bucket': BUCKET_NAME,
                'Key': object_key
            },
            ExpiresIn=3600
        )
    except Exception as e:
        print(f"Error generating pre-signed URL: {e}")
        return None

            
# Gradio UI: garment URL + preview, model photo input, and the try-on result.
with gr.Blocks() as demo:
    with gr.Row():
        image_url_input = gr.Textbox(label="Image URL", placeholder="Enter image URL here")
        input_garment_image = gr.Image(label="Garment Image", type="pil", width="384px", height = "512px")
        uploaded_image = gr.Image(label="Upload or Capture Image", type="pil", width="384px", height="512px")
        output_display = gr.Image(label="Displayed Image or URL Result", width="384px", height="512px")

    # Refresh the garment preview whenever the URL textbox changes.
    image_url_input.change(
        load_image_from_url,
        inputs=image_url_input,
        outputs=input_garment_image
    )
    submit_btn = gr.Button("Submit")
    # Run the full try-on pipeline on the model photo + garment URL.
    submit_btn.click(
        display_image,
        inputs=[uploaded_image, image_url_input],
        outputs=output_display
    )

if __name__ == "__main__":
    # Only start the server (with a public share link) when run as a script,
    # so importing this module does not launch the app as a side effect.
    demo.launch(share=True)