"""Gradio demo for virtual garment try-on.

The user supplies a garment image URL and a photo of a person. Both images
are letterboxed to the model's input size and uploaded to S3. Presigned URLs
for them are sent to a Modal inference endpoint, and the returned try-on
image is displayed.
"""
import logging
import os
import uuid
from io import BytesIO
from urllib.parse import unquote, urlparse

import boto3
import gradio as gr
import requests
from botocore.client import Config
from botocore.exceptions import BotoCoreError, ClientError
from dotenv import load_dotenv
from flask import Flask, jsonify, request, send_file
from PIL import Image

load_dotenv()

logger = logging.getLogger(__name__)

AWS_ACCESS_KEY_ID = os.getenv("AWS_ACCESS_KEY_ID")
AWS_SECRET_ACCESS_KEY = os.getenv("AWS_SECRET_ACCESS_KEY")
BUCKET_NAME = "tech-tailor"

s3_client = boto3.client(
    "s3",
    region_name="ap-south-1",
    aws_access_key_id=AWS_ACCESS_KEY_ID,
    aws_secret_access_key=AWS_SECRET_ACCESS_KEY,
    config=Config(signature_version="s3v4"),
)

MODAL_INFERENCE_ENDPOINT_URL = os.getenv("MODAL_INFERENCE_ENDPOINT_URL")

app = Flask(__name__)

GARM_SAVE_DIR = "garment_images"
MODE_SAVE_DIR = "model_images"
garment_upload_dir = "gradio_demo_garment/"
model_upload_dir = "gradio_demo_model/"

# (width, height) canvas every input image is letterboxed onto before upload.
TARGET_SIZE = (768, 1024)
# (width, height) the try-on result is resized to for display.
OUTPUT_SIZE = (512, 682)
# Lifetime, in seconds, of the presigned S3 URLs this module generates.
PRESIGNED_URL_EXPIRY = 3600
# (connect, read) timeouts in seconds. Without a timeout, a stalled server
# would block the Gradio worker forever.
DOWNLOAD_TIMEOUT = (10, 60)
# Inference can be slow, so the read timeout is generous (assumption: tune
# to the endpoint's real latency).
INFERENCE_TIMEOUT = (10, 600)


def _is_blank(text):
    """Return True when *text* is None, empty, or whitespace only."""
    return not text or not text.strip()


def _download_image(image_url, timeout=DOWNLOAD_TIMEOUT):
    """Fetch *image_url* and return it decoded as a PIL image.

    Raises:
        requests.exceptions.RequestException: on network errors or a non-2xx
            status.
        OSError: if Pillow cannot decode the body
            (``PIL.UnidentifiedImageError`` is an ``OSError``).
    """
    response = requests.get(image_url, timeout=timeout)
    response.raise_for_status()
    img = Image.open(BytesIO(response.content))
    # Force decoding now so corrupt data fails here, not later during resize.
    img.load()
    return img


def _letterbox(img, target_size=TARGET_SIZE):
    """Return an RGB copy of *img* scaled to fit *target_size* and centred on black.

    The aspect ratio is preserved: the image is scaled by the smaller of the
    width and height ratios, and the leftover area is padded with black.
    """
    img = img.convert("RGB")
    target_width, target_height = target_size
    img_width, img_height = img.size
    scale = min(target_width / img_width, target_height / img_height)
    # max(1, ...) guards against a zero-sized resize for extremely thin images.
    new_width = max(1, int(img_width * scale))
    new_height = max(1, int(img_height * scale))
    resized = img.resize((new_width, new_height), Image.Resampling.LANCZOS)
    canvas = Image.new("RGB", (target_width, target_height), (0, 0, 0))
    offset = ((target_width - new_width) // 2, (target_height - new_height) // 2)
    canvas.paste(resized, offset)
    return canvas


def _upload_jpeg(img, prefix):
    """Upload *img* as a JPEG under *prefix* in BUCKET_NAME.

    Returns:
        A presigned GET URL for the new object, valid for
        PRESIGNED_URL_EXPIRY seconds.

    Raises:
        botocore.exceptions.ClientError / BotoCoreError: if S3 rejects the
            upload.
    """
    buffer = BytesIO()
    img.save(buffer, format="JPEG")
    buffer.seek(0)
    # A random name avoids collisions between concurrent users.
    key = f"{prefix}{uuid.uuid4().hex}.jpg"
    s3_client.put_object(
        Body=buffer, Bucket=BUCKET_NAME, Key=key, ContentType="image/jpeg"
    )
    return s3_client.generate_presigned_url(
        "get_object",
        Params={"Bucket": BUCKET_NAME, "Key": key},
        ExpiresIn=PRESIGNED_URL_EXPIRY,
    )


def _upload_garment(image_url):
    """Download, letterbox, and upload the garment image; return its presigned URL.

    Raises the exceptions of _download_image and _upload_jpeg instead of
    converting them to strings.
    """
    img = _download_image(image_url.strip())
    return _upload_jpeg(_letterbox(img), garment_upload_dir)


def load_image_from_url(image_url):
    """Preview the garment at *image_url*; return a PIL image, or None.

    Gradio calls this on every edit of the URL textbox, so it never raises.
    It returns None for blank input, for responses whose Content-Type is
    not an image, and for network or decoding errors (which are logged).
    """
    if _is_blank(image_url):
        return None
    try:
        response = requests.get(image_url.strip(), timeout=DOWNLOAD_TIMEOUT)
        response.raise_for_status()
        if "image" not in response.headers.get("Content-Type", ""):
            return None
        return Image.open(BytesIO(response.content))
    except Exception as e:  # UI callback boundary: report and show no preview.
        logger.warning("Error loading image: %s", e)
        return None


def process_cloth_image(image_url):
    """Download, letterbox, and upload the garment image at *image_url*.

    Returns:
        The presigned S3 URL on success. For backward compatibility,
        failures are returned as human-readable strings rather than raised:
        "No image provided", "Error downloading image: ...", or
        "Error processing image: ...".
    """
    if _is_blank(image_url):
        return "No image provided"
    try:
        return _upload_garment(image_url)
    except requests.exceptions.RequestException as e:
        return f"Error downloading image: {e}"
    except Exception as e:
        return f"Error processing image: {e}"


def process_model_image(image):
    """Letterbox the person photo *image* (a PIL image) and upload it to S3.

    Returns:
        The presigned S3 URL of the uploaded JPEG.

    Raises:
        ValueError: if *image* is None.
        botocore.exceptions.ClientError / BotoCoreError: if the upload fails.
    """
    if image is None:
        raise ValueError("No model image provided")
    return _upload_jpeg(_letterbox(image), model_upload_dir)


def display_image(image, image_url):
    """Run virtual try-on for the person photo *image* and the garment at *image_url*.

    Returns:
        The try-on result as a PIL image resized to OUTPUT_SIZE.

    Raises:
        gr.Error: on any input, upload, or inference failure, so Gradio shows
            the message in the UI instead of trying to render an invalid
            value in the Image output.
    """
    if image is None:
        raise gr.Error("Please upload or capture a model image.")
    if _is_blank(image_url):
        raise gr.Error("Please enter a garment image URL.")
    if not MODAL_INFERENCE_ENDPOINT_URL:
        raise gr.Error("MODAL_INFERENCE_ENDPOINT_URL is not configured.")

    try:
        garment_url = _upload_garment(image_url)
        model_url = process_model_image(image)
    except requests.exceptions.RequestException as e:
        raise gr.Error(f"Error downloading image: {e}") from e
    except (OSError, ClientError, BotoCoreError) as e:
        raise gr.Error(f"Error processing image: {e}") from e

    payload = {"human_image_url": model_url, "garment_image_url": garment_url}
    # Presigned URLs grant read access, so keep them out of default-level logs.
    logger.debug("Try-on payload: %s", payload)

    try:
        response = requests.post(
            MODAL_INFERENCE_ENDPOINT_URL, json=payload, timeout=INFERENCE_TIMEOUT
        )
    except requests.exceptions.RequestException as e:
        raise gr.Error(f"Request failed for the garment image. Error: {str(e)}") from e

    if response.status_code != 200:
        raise gr.Error(
            f"Failed to process the garment image. Status Code: {response.status_code}"
        )

    try:
        # The endpoint is expected to answer {"url": <result image URL>}.
        result_url = response.json()["url"]
        result = _download_image(result_url)
    except (requests.exceptions.RequestException, ValueError, KeyError, TypeError, OSError) as e:
        raise gr.Error(f"Could not retrieve the try-on result: {e}") from e

    return result.resize(OUTPUT_SIZE)


def generate_presigned_url(object_url):
    """Return a fresh presigned GET URL for the BUCKET_NAME object referenced by *object_url*.

    Both S3 URL styles are accepted:

    * path-style: ``https://s3.<region>.amazonaws.com/<bucket>/<key>``
      (the first path segment is the bucket and is dropped);
    * virtual-hosted-style: ``https://<bucket>.s3.<region>.amazonaws.com/<key>``
      (the whole path is the key).

    Returns:
        The presigned URL, or None if no key can be extracted or signing fails.
    """
    parsed_url = urlparse(object_url)
    # Keys appear percent-encoded in URLs; S3 expects the raw key.
    path = unquote(parsed_url.path.lstrip("/"))
    if parsed_url.netloc.startswith(f"{BUCKET_NAME}."):
        object_key = path
    else:
        _, _, object_key = path.partition("/")
    logger.info("Extracted Object Key: %s", object_key)

    if not object_key:
        logger.warning("No object key found in URL: %s", object_url)
        return None

    try:
        return s3_client.generate_presigned_url(
            "get_object",
            Params={"Bucket": BUCKET_NAME, "Key": object_key},
            ExpiresIn=PRESIGNED_URL_EXPIRY,
        )
    except (ClientError, BotoCoreError) as e:
        logger.error("Error generating pre-signed URL: %s", e)
        return None


with gr.Blocks() as demo:
    with gr.Row():
        image_url_input = gr.Textbox(
            label="Image URL", placeholder="Enter image URL here"
        )
        input_garment_image = gr.Image(
            label="Garment Image", type="pil", width="384px", height="512px"
        )
        uploaded_image = gr.Image(
            label="Upload or Capture Image", type="pil", width="384px", height="512px"
        )
        output_display = gr.Image(
            label="Displayed Image or URL Result", width="384px", height="512px"
        )

    # Show a live preview of the garment as the URL is typed or pasted.
    image_url_input.change(
        load_image_from_url, inputs=image_url_input, outputs=input_garment_image
    )

    submit_btn = gr.Button("Submit")
    submit_btn.click(
        display_image,
        inputs=[uploaded_image, image_url_input],
        outputs=output_display,
    )

demo.launch(share=True)