| | import gradio as gr |
| | from dotenv import load_dotenv |
| | import requests |
| | from flask import Flask, jsonify, request, send_file |
| | from botocore.exceptions import ClientError |
| | from botocore.client import Config |
| | import boto3 |
| | from urllib.parse import urlparse |
| | import os |
| | from PIL import Image |
| | from io import BytesIO |
| | import uuid |
| |
|
| |
|
# Merge variables from a local .env file (if one exists) into os.environ;
# this must happen before any of the environment lookups below.
load_dotenv()

# --- S3 storage --------------------------------------------------------------
# Credentials are read from the environment; bucket and region are fixed here.
BUCKET_NAME = "tech-tailor"
AWS_ACCESS_KEY_ID = os.environ.get("AWS_ACCESS_KEY_ID")
AWS_SECRET_ACCESS_KEY = os.environ.get("AWS_SECRET_ACCESS_KEY")
s3_client = boto3.client(
    "s3",
    region_name='ap-south-1',
    # Signature Version 4 is used for the pre-signed URLs generated below.
    config=Config(signature_version='s3v4'),
    aws_access_key_id=AWS_ACCESS_KEY_ID,
    aws_secret_access_key=AWS_SECRET_ACCESS_KEY,
)

# Key prefixes ("folders") inside BUCKET_NAME for the preprocessed uploads.
garment_upload_dir = "gradio_demo_garment/"
model_upload_dir = "gradio_demo_model/"

# --- Inference ---------------------------------------------------------------
# Endpoint that receives the garment/model image URLs; None if unset.
MODAL_INFERENCE_ENDPOINT_URL = os.environ.get("MODAL_INFERENCE_ENDPOINT_URL")

# --- Miscellaneous -----------------------------------------------------------
# NOTE(review): app, GARM_SAVE_DIR and MODE_SAVE_DIR are not referenced by the
# code visible in this file -- confirm whether anything still depends on them.
app = Flask(__name__)
GARM_SAVE_DIR = "garment_images"
MODE_SAVE_DIR = "model_images"
| |
|
def load_image_from_url(image_url):
    """Fetch the image at *image_url* for the garment preview widget.

    This is registered as the URL text box's ``change`` handler, so it runs on
    every edit. Blank input is therefore answered with ``None`` immediately,
    without a network request or an error message.

    Args:
        image_url: URL typed by the user; may be empty, whitespace or ``None``.

    Returns:
        A fully decoded PIL image. Returns ``None`` if the URL is blank, the
        response's ``Content-Type`` does not mention ``image``, or any error
        occurs (the error is printed).
    """
    url = (image_url or "").strip()
    if not url:
        return None
    try:
        response = requests.get(url, timeout=30)
        # .get(): a response without a Content-Type header is "not an image",
        # not a KeyError.
        if "image" not in response.headers.get("Content-Type", ""):
            return None
        img = Image.open(BytesIO(response.content))
        # Image.open only parses the header; decode now so corrupt or
        # truncated data is caught here rather than later in Gradio.
        img.load()
        return img
    except Exception as e:
        print(f"Error loading image: {e}")
        return None
| | |
def process_cloth_image(image_url):
    """Download a garment image, letterbox it to 768x1024 and upload it to S3.

    The image is converted to RGB and scaled by one factor so it fits inside
    768x1024 without cropping. It is then centred on a black canvas, encoded as
    JPEG and stored under ``garment_upload_dir`` with a random file name.

    Args:
        image_url: URL of the garment image; surrounding whitespace is ignored.

    Returns:
        On success, a pre-signed GET URL for the uploaded JPEG, valid for 3600
        seconds. On failure, a human-readable message is returned *instead of*
        a URL:

        * ``"No image provided"``
        * ``"Error downloading image: ..."``
        * ``"Error processing image: ..."``

        Callers must tell the two cases apart.
    """
    url = (image_url or "").strip()
    if not url:
        return "No image provided"
    try:
        response = requests.get(url, timeout=30)
        response.raise_for_status()
        img = Image.open(BytesIO(response.content)).convert("RGB")

        target_width, target_height = 768, 1024
        # Uniform scale so the whole garment fits; the leftover area is padded.
        scale_factor = min(target_width / img.width, target_height / img.height)
        # max(1, ...) keeps extreme aspect ratios from collapsing a side to 0,
        # which Image.resize rejects.
        new_width = max(1, int(img.width * scale_factor))
        new_height = max(1, int(img.height * scale_factor))
        img = img.resize((new_width, new_height), Image.Resampling.LANCZOS)

        new_img = Image.new("RGB", (target_width, target_height), (0, 0, 0))
        left_padding = (target_width - new_width) // 2
        top_padding = (target_height - new_height) // 2
        new_img.paste(img, (left_padding, top_padding))

        img_byte_array = BytesIO()
        new_img.save(img_byte_array, format="JPEG")
        img_byte_array.seek(0)

        key = garment_upload_dir + f"{uuid.uuid4().hex}.jpg"
        s3_client.put_object(
            Body=img_byte_array, Bucket=BUCKET_NAME, Key=key, ContentType='image/jpeg'
        )
        return s3_client.generate_presigned_url(
            'get_object',
            Params={'Bucket': BUCKET_NAME, 'Key': key},
            ExpiresIn=3600,
        )
    except requests.exceptions.RequestException as e:
        return f"Error downloading image: {e}"
    except Exception as e:
        return f"Error processing image: {e}"
| |
|
def process_model_image(image):
    """Letterbox a person/model photo to 768x1024 and upload it to S3.

    The photo is converted to RGB and shrunk or enlarged by a single factor so
    it fits inside 768x1024 without cropping. It is then centred on a black
    canvas, encoded as JPEG and stored under ``model_upload_dir`` with a
    random file name.

    Args:
        image: PIL image from the upload widget.

    Returns:
        Pre-signed GET URL for the uploaded JPEG, valid for 3600 seconds.
    """
    canvas_size = (768, 1024)
    rgb = image.convert("RGB")
    source_w, source_h = rgb.size
    factor = min(canvas_size[0] / source_w, canvas_size[1] / source_h)
    fitted_size = (int(source_w * factor), int(source_h * factor))
    fitted = rgb.resize(fitted_size, Image.Resampling.LANCZOS)

    # Centre the fitted photo; the uncovered border stays black.
    canvas = Image.new("RGB", canvas_size, (0, 0, 0))
    offset = tuple((full - part) // 2 for full, part in zip(canvas_size, fitted_size))
    canvas.paste(fitted, offset)

    jpeg_bytes = BytesIO()
    canvas.save(jpeg_bytes, format="JPEG")
    jpeg_bytes.seek(0)

    object_key = model_upload_dir + f"{uuid.uuid4().hex}.jpg"
    s3_client.put_object(
        Body=jpeg_bytes, Bucket=BUCKET_NAME, Key=object_key, ContentType='image/jpeg'
    )
    return s3_client.generate_presigned_url(
        'get_object',
        Params={'Bucket': BUCKET_NAME, 'Key': object_key},
        ExpiresIn=3600,
    )
| |
|
| | |
def display_image(image, image_url):
    """Run the virtual try-on for a person photo and a garment URL.

    The garment URL goes through ``process_cloth_image`` and the photo through
    ``process_model_image``. The two resulting pre-signed S3 URLs are POSTed as
    JSON to ``MODAL_INFERENCE_ENDPOINT_URL``. The image at the reply's
    ``"url"`` field is then downloaded and returned.

    Args:
        image: Person/model photo from the upload widget (PIL image), or
            ``None`` if the user has not provided one.
        image_url: Garment image URL typed into the text box.

    Returns:
        The try-on result, resized to 512x682.

    Raises:
        gr.Error: Raised so Gradio shows the reason instead of a blank output,
            for any of these cases:

            * the photo is missing;
            * the garment could not be processed;
            * the inference call failed or returned a non-200 status;
            * the result could not be downloaded or decoded.

            Exceptions raised inside ``process_model_image`` (e.g. S3 upload
            errors) are not caught here.
    """
    if image is None:
        raise gr.Error("Please upload or capture a model image.")

    garment_url = process_cloth_image(image_url)
    # process_cloth_image reports failure by *returning* a message string
    # ("No image provided", "Error ...") rather than raising, so only forward
    # values that really are URLs to the inference endpoint.
    if urlparse(garment_url).scheme not in ("http", "https"):
        raise gr.Error(garment_url)
    model_url = process_model_image(image)

    payload = {
        "human_image_url": model_url,
        "garment_image_url": garment_url,
    }
    try:
        response = requests.post(MODAL_INFERENCE_ENDPOINT_URL, json=payload)
    except requests.exceptions.RequestException as e:
        raise gr.Error(f"Request failed for the garment image. Error: {str(e)}") from e
    if response.status_code != 200:
        raise gr.Error(
            f"Failed to process the garment image. Status Code: {response.status_code}"
        )

    try:
        result_url = response.json()["url"]
        result = requests.get(result_url, timeout=60)
        # Without this, an HTTP error page would reach Image.open and fail with
        # an unhelpful "cannot identify image file" error.
        result.raise_for_status()
        img = Image.open(BytesIO(result.content))
        img.load()  # Image.open is lazy; decode now so bad data is caught here.
    except (requests.exceptions.RequestException, ValueError, KeyError, TypeError, OSError) as e:
        raise gr.Error(f"Could not fetch the try-on result. Error: {e}") from e

    return img.resize((512, 682))
| |
|
def generate_presigned_url(object_url):
    """Return a fresh pre-signed GET URL for the S3 object named by *object_url*.

    Both S3 URL styles are understood:

    * virtual-hosted: ``https://<BUCKET_NAME>.s3.<region>.amazonaws.com/<key>``
      -- the whole path is the key;
    * path-style: ``https://s3.<region>.amazonaws.com/<bucket>/<key>``
      -- the first path segment is the bucket and is dropped.

    The object is always signed against ``BUCKET_NAME``. Any query string on
    *object_url* (e.g. an expired signature) is ignored.

    Args:
        object_url: URL of the object; the path may be percent-encoded.

    Returns:
        The pre-signed URL (valid for 3600 seconds), or ``None`` if no object
        key could be extracted or signing failed (the reason is printed).
    """
    from urllib.parse import unquote

    parsed_url = urlparse(object_url)
    path = unquote(parsed_url.path.lstrip('/'))
    host = parsed_url.hostname or ''
    if host.startswith(f"{BUCKET_NAME}."):
        # Virtual-hosted style: the bucket lives in the host name.
        object_key = path
    else:
        # Path-style: strip the leading bucket segment.
        path_parts = path.split('/', 1)
        object_key = path_parts[1] if len(path_parts) > 1 else ''
    print(f"Extracted Object Key: {object_key}")
    if not object_key:
        print("Error generating pre-signed URL: no object key found in URL")
        return None
    try:
        return s3_client.generate_presigned_url(
            'get_object',
            Params={
                'Bucket': BUCKET_NAME,
                'Key': object_key,
            },
            ExpiresIn=3600,
        )
    except Exception as e:
        print(f"Error generating pre-signed URL: {e}")
        return None
| |
|
| | |
# Layout CSS: the four widgets sit side by side on wide screens and stack
# vertically at 768px and below. It must be defined before the Blocks are
# built so it can be handed to gr.Blocks(css=...); otherwise it is never
# applied to the page.
custom_css = """
/* Default layout for PC screens */
#responsive-container {
    display: flex;
    flex-direction: row;
    gap: 20px;
    justify-content: center;
    align-items: center;
}
#responsive-container .gr-box {
    flex: 1;
    max-width: 384px;
}

/* Responsive layout for mobile screens */
@media (max-width: 768px) {
    #responsive-container {
        flex-wrap: wrap;
        flex-direction: column;
    }
    #responsive-container .gr-box {
        flex: 1 1 100%;
        max-width: 100%;
    }
}
"""

with gr.Blocks(css=custom_css) as demo:
    gr.Markdown("## Virtual Try-On")

    with gr.Row(elem_id="responsive-container"):
        image_url_input = gr.Textbox(
            label="Image URL",
            placeholder="Enter image URL here",
            interactive=True,
            elem_id="url-input",
        )
        input_garment_image = gr.Image(
            label="Garment Image",
            type="pil",
            interactive=True,
            elem_id="garment-image",
        )
        uploaded_image = gr.Image(
            label="Upload or Capture Image",
            type="pil",
            interactive=True,
            elem_id="uploaded-image",
        )
        output_display = gr.Image(
            label="Displayed Image or URL Result",
            elem_id="output-image",
        )

    # Refresh the garment preview whenever the URL text changes.
    image_url_input.change(
        load_image_from_url,
        inputs=image_url_input,
        outputs=input_garment_image,
    )

    submit_btn = gr.Button("Submit", elem_id="submit-btn")
    # NOTE(review): the try-on reads the garment from the URL text box, not
    # from input_garment_image, so a garment uploaded directly into that
    # widget is ignored -- confirm whether that is intended.
    submit_btn.click(
        display_image,
        inputs=[uploaded_image, image_url_input],
        outputs=output_display,
    )

demo.launch(share=True)
| |
|