|
|
import os |
|
|
import cv2 |
|
|
import numpy as np |
|
|
from PIL import Image |
|
|
import gradio as gr |
|
|
import mediapipe as mp |
|
|
from io import BytesIO |
|
|
import time |
|
|
import requests |
|
|
import base64 |
|
|
from simple_salesforce import Salesforce |
|
|
import logging |
|
|
|
|
|
|
|
|
# Module logger used by every function below; the root logger is configured
# once at import time so INFO-level records from this app are emitted.
logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)
|
|
|
|
|
|
|
|
# One MediaPipe Pose estimator shared by every request in this process.
# static_image_mode=True runs full detection on each image instead of
# tracking across frames; model_complexity=2 picks the most accurate (and
# slowest) of MediaPipe's 0/1/2 pose models.
mp_pose = mp.solutions.pose
pose = mp_pose.Pose(model_complexity=2, static_image_mode=True)
|
|
|
|
|
|
|
|
# Salesforce login settings, read from the environment only.
# SECURITY: a real username, password and security token used to be
# hard-coded here as fallback defaults. They were committed to source control
# and must be treated as leaked and rotated. A missing variable now yields
# None, so the login attempt in SalesforceConnector.connect() fails and is
# logged instead of silently using embedded secrets.
SF_CONFIG = {
    "username": os.environ.get("SF_USERNAME"),
    "password": os.environ.get("SF_PASSWORD"),
    "security_token": os.environ.get("SF_SECURITY_TOKEN"),
    # simple_salesforce domain: "login" for production orgs, "test" for sandboxes.
    "domain": os.environ.get("SF_DOMAIN", "login"),
}
|
|
|
|
|
class SalesforceConnector:
    """Holds a simple_salesforce session and re-logs in when it is missing.

    ``self.sf`` is the live ``Salesforce`` client, or ``None`` when the last
    login attempt failed or the last query failed (the session is dropped so
    the next call re-authenticates).
    """

    # Dress types accepted as a SOQL filter. Only these fixed literals are ever
    # interpolated into the query string, so caller input never reaches SOQL.
    ALLOWED_DRESS_TYPES = ("Casual", "Formal")

    def __init__(self):
        self.sf = None
        self.connect()

    def connect(self):
        """Log in using ``SF_CONFIG``; on any failure log it and leave ``sf`` as None."""
        try:
            self.sf = Salesforce(
                username=SF_CONFIG["username"],
                password=SF_CONFIG["password"],
                security_token=SF_CONFIG["security_token"],
                domain=SF_CONFIG["domain"],
            )
            logger.info("Successfully connected to Salesforce")
        except Exception as e:
            # Deliberately best-effort: the UI shows "Disconnected" instead of crashing.
            logger.error("Salesforce connection failed: %s", e)
            self.sf = None

    def is_connected(self):
        """Return True when a Salesforce session is currently held."""
        return self.sf is not None

    def query_dresses(self, dress_type=None):
        """Return up to 100 ``Dress__c`` records, optionally filtered by type.

        Args:
            dress_type: "Casual" or "Formal" to filter on ``Dress_Type__c``;
                any other value (including None) returns all types.

        Returns:
            A list of record dicts with ``Id``, ``Name``, ``Image_URL__c`` and
            ``Dress_Type__c``. It is an empty list when no session can be
            established or the query fails.
        """
        if not self.is_connected():
            self.connect()
            if not self.is_connected():
                logger.error("Failed to reconnect to Salesforce")
                # Previously returned {}; every other path returns a list.
                return []

        query = "SELECT Id, Name, Image_URL__c, Dress_Type__c FROM Dress__c"
        if dress_type in self.ALLOWED_DRESS_TYPES:
            query += f" WHERE Dress_Type__c = '{dress_type}'"
        query += " LIMIT 100"

        try:
            result = self.sf.query(query)
        except Exception as e:
            logger.error("Salesforce query failed: %s", e)
            # Drop the (possibly expired) session so the next call logs in again;
            # otherwise a dead session would make every later query fail.
            self.sf = None
            return []

        records = result.get("records", [])
        logger.info("Fetched %d records for %s dresses", len(records), dress_type or "all")
        return records
|
|
|
|
|
# Module-level connector shared by all callbacks below. Constructing it
# attempts a Salesforce login at import time; a failed login is logged and
# leaves the connector disconnected rather than raising.
sf_connector = SalesforceConnector()
|
|
|
|
|
def fetch_dresses(dress_type=None):
    """Fetch dresses with automatic reconnection, filtered by dress type.

    Records come from ``sf_connector.query_dresses``, which re-logs in if no
    session is held. Each record's ``Image_URL__c`` is downloaded. Records
    with no URL, or whose image cannot be fetched or decoded, are skipped
    with a warning.

    Args:
        dress_type: "Casual"/"Formal" to filter; any other value loads all types.

    Returns:
        dict mapping Salesforce record Id to ``{"name", "image"`` (RGBA PIL
        image), ``"thumbnail"`` (PNG data URI), ``"type"}``.
    """
    dresses = {}

    for item in sf_connector.query_dresses(dress_type):
        name = item.get("Name")
        url = item.get("Image_URL__c")
        logger.info("Processing record: %s with Image_URL__c: %s", name, url)
        if not url:
            logger.warning("Skipping dress %s due to missing Image_URL__c", name)
            continue

        try:
            response = requests.get(url, timeout=10)
            response.raise_for_status()
            img = Image.open(BytesIO(response.content)).convert("RGBA")
            dresses[item["Id"]] = {
                "name": item["Name"],
                "image": img,
                "thumbnail": generate_thumbnail(img),
                # Dress_Type__c is always SELECTed, so the key is present and a
                # null field arrives as None; .get(key, default) never fell back.
                "type": item.get("Dress_Type__c") or "Unknown",
            }
        except Exception as e:
            # Best-effort: one broken image must not hide the rest of the catalogue.
            logger.warning("Failed to load dress %s: %s", name, e)

    logger.info("Cached %d %s dresses", len(dresses), dress_type or "all")
    return dresses
|
|
|
|
|
def generate_thumbnail(img, size=(150, 200)):
    """Return a ``data:image/png;base64,...`` URI of *img* shrunk to fit *size*.

    The caller's image is never modified. The in-place ``thumbnail`` resize is
    applied to a copy.
    """
    preview = img.copy()
    preview.thumbnail(size)
    with BytesIO() as png_buffer:
        preview.save(png_buffer, format="PNG")
        payload = base64.b64encode(png_buffer.getvalue()).decode()
    return "data:image/png;base64," + payload
|
|
|
|
|
|
|
|
def get_body_landmarks(image):
    """Detect shoulder and hip landmarks with the module-level MediaPipe ``pose``.

    Args:
        image: a PIL image (``process_try_on`` passes the Gradio upload) or an
            array; arrays are assumed to be RGB, RGBA or single-channel.

    Returns:
        Four normalized ``(x, y)`` tuples in the order left shoulder, right
        shoulder, left hip, right hip. Returns an empty list when no pose is
        found or detection fails (the failure is logged).
    """
    try:
        # MediaPipe expects 3-channel RGB. PIL images are already RGB, so the
        # previous COLOR_BGR2RGB call actually handed MediaPipe BGR data.
        if hasattr(image, "convert"):
            frame = np.asarray(image.convert("RGB"))
        else:
            frame = np.asarray(image)
            if frame.ndim == 2:
                frame = cv2.cvtColor(frame, cv2.COLOR_GRAY2RGB)
            elif frame.shape[2] == 4:
                frame = cv2.cvtColor(frame, cv2.COLOR_RGBA2RGB)

        results = pose.process(np.ascontiguousarray(frame))
        if not results.pose_landmarks:
            return []

        points = results.pose_landmarks.landmark
        wanted = (
            mp_pose.PoseLandmark.LEFT_SHOULDER,
            mp_pose.PoseLandmark.RIGHT_SHOULDER,
            mp_pose.PoseLandmark.LEFT_HIP,
            mp_pose.PoseLandmark.RIGHT_HIP,
        )
        return [(points[part].x, points[part].y) for part in wanted]
    except Exception as e:
        logger.error("Body landmark detection failed: %s", e)
        return []
|
|
|
|
|
def warp_and_overlay(user_img, dress_img, body_pts):
    """Alpha-blend a scaled dress image onto the user photo at the shoulders.

    Placement:
    - The dress is scaled so its width is 80% of the detected shoulder width.
    - It is centred horizontally between the shoulders, with its top edge at
      shoulder height.
    - Its position is clamped to stay inside the photo, and any part that
      still overflows is cropped.

    Args:
        user_img: HxWxC array (C >= 3; only the first three channels are
            blended). Modified in place.
        dress_img: dress array, HxWx4 (alpha channel used as the blend mask),
            HxWx3, or HxW grayscale.
        body_pts: normalized ``(x, y)`` landmarks; items 0 and 1 are the two
            shoulders (as produced by ``get_body_landmarks``).

    Returns:
        ``user_img`` (the same array). It is unchanged when ``body_pts`` is
        empty, the scaled dress would be empty, or any step fails (failures are
        logged).
    """
    try:
        if not body_pts:
            return user_img

        user_h, user_w = user_img.shape[:2]
        dress_h, dress_w = dress_img.shape[:2]
        (lx, ly), (rx, ry) = body_pts[0], body_pts[1]

        shoulder_width = abs(lx - rx) * user_w
        scale = shoulder_width / dress_w * 0.8
        new_w = int(dress_w * scale)
        new_h = int(dress_h * scale)
        if new_w <= 0 or new_h <= 0:
            # Shoulders (nearly) coincide; cv2.resize rejects a zero-size target.
            return user_img

        resized = cv2.resize(dress_img, (new_w, new_h), interpolation=cv2.INTER_AREA)

        x = int((lx + rx) * user_w / 2 - new_w / 2)
        y = int((ly + ry) * user_h / 2)
        x = max(0, min(x, user_w - new_w))
        y = max(0, min(y, user_h - new_h))

        # A dress larger than the photo is pinned to 0 by the clamp above; crop
        # the overflow so the dress and the target region have equal shapes.
        vis_w = min(new_w, user_w - x)
        vis_h = min(new_h, user_h - y)
        resized = resized[:vis_h, :vis_w]

        # Keep the mask as (h, w, 1) so it broadcasts across the RGB channels.
        # The old code multiplied a (h, w, 3) mask by 2-D per-channel slices,
        # which raised a broadcasting error for every RGBA dress.
        if resized.ndim == 2:
            dress_rgb = np.repeat(resized[:, :, np.newaxis], 3, axis=2)
            alpha = np.ones((vis_h, vis_w, 1))
        elif resized.shape[2] == 4:
            dress_rgb = resized[:, :, :3]
            alpha = resized[:, :, 3:4] / 255.0
        else:
            dress_rgb = resized[:, :, :3]
            alpha = np.ones((vis_h, vis_w, 1))

        region = user_img[y:y + vis_h, x:x + vis_w, :3].astype(np.float64)
        blended = region * (1.0 - alpha) + dress_rgb * alpha
        if np.issubdtype(user_img.dtype, np.integer):
            # Round rather than truncate when writing back into integer pixels.
            blended = np.rint(blended)
        user_img[y:y + vis_h, x:x + vis_w, :3] = blended

        return user_img
    except Exception as e:
        logger.error("Dress overlay failed: %s", e)
        return user_img
|
|
|
|
|
def process_try_on(user_image, dress_id, dress_type):
    """Run one virtual try-on request from the Gradio UI.

    Args:
        user_image: uploaded photo (PIL image from ``gr.Image(type="pil")``) or None.
        dress_id: Salesforce record Id chosen in the dress dropdown, or None.
        dress_type: dress type used to reload the catalogue (e.g. "Casual").

    Returns:
        ``(image, status_message)``. On success the image is the composited
        PIL image; otherwise it is the original upload (or None) and the
        message explains why.
    """
    if user_image is None:
        return None, "Please upload an image"
    # Check the cheap input first: pose detection and the catalogue download
    # below are expensive and pointless when no dress is selected.
    if dress_id is None:
        return user_image, "Please select a valid dress from the dropdown"

    start_time = time.time()
    logger.info("Processing try-on with dress_id: %s, dress_type: %s", dress_id, dress_type)

    try:
        # Normalise to 3-channel RGB so landmark detection and blending see the
        # same layout whatever mode the uploaded file had.
        rgb_image = user_image.convert("RGB") if hasattr(user_image, "convert") else user_image
        user_img = np.array(rgb_image)  # np.array copies, so the upload is never mutated

        body_pts = get_body_landmarks(rgb_image)
        if not body_pts:
            return user_image, "Stand facing camera with arms slightly away"

        dresses = fetch_dresses(dress_type)
        if not dresses:
            return user_image, f"No {dress_type} dresses available - check Salesforce connection"

        if dress_id not in dresses:
            logger.error("Invalid or None dress_id: %s, available: %s", dress_id, list(dresses))
            return user_image, "Please select a valid dress from the dropdown"

        dress_img = np.array(dresses[dress_id]["image"])
        result = warp_and_overlay(user_img, dress_img, body_pts)

        return Image.fromarray(result), f"Done in {time.time()-start_time:.2f}s"
    except Exception as e:
        logger.exception("Try-on failed: %s", e)
        return user_image, f"Error: {str(e)}"
|
|
|
|
|
|
|
|
with gr.Blocks(title="Virtual Try-On", css=".thumbnail { height: 100px !important }") as demo:
    # NOTE(review): the "๐" glyphs in the labels below look like emoji that were
    # mis-encoded when this file was saved; confirm the intended characters.
    gr.Markdown("# ๐ Virtual Try-On (Salesforce Connected)")

    with gr.Row():
        with gr.Column():
            input_image = gr.Image(sources=["upload", "webcam"], type="pil", label="Your Photo")
            dress_type = gr.Dropdown(
                label="Select Dress Type",
                choices=["Casual", "Formal"],
                value="Casual",
                interactive=True,
            )
            with gr.Row():
                refresh_btn = gr.Button("๐ Refresh Dresses")
                try_btn = gr.Button("๐ Try On Dress", variant="primary")
            dress_dropdown = gr.Dropdown(label="Select Dress", interactive=True, value=None)
            connection_status = gr.Textbox(
                label="Salesforce Status",
                value="Connected" if sf_connector.is_connected() else "Disconnected",
            )
            refresh_manual = gr.Button("๐ Manual Refresh")

        with gr.Column():
            output_image = gr.Image(label="Your Virtual Try-On", interactive=False)
            status = gr.Textbox(label="Status")

    def update_dress_dropdown(dress_type):
        """Reload the catalogue for *dress_type* and refresh the dress picker.

        Returns a ``(dropdown update, status text)`` pair for the
        ``[dress_dropdown, connection_status]`` outputs. The current selection
        is cleared because its Id may not exist in the reloaded list.
        """
        dresses = fetch_dresses(dress_type)
        choices = [(f"{d['name']} ({d['type']})", dress_id) for dress_id, d in dresses.items()]
        sf_state = "Connected" if sf_connector.is_connected() else "Disconnected"
        logger.info(
            "Updating dropdown with %d choices for %s, choices: %s",
            len(choices), dress_type, choices,
        )
        # A bare dict is treated by Gradio as the dropdown's *value*, so the
        # choices and label were never updated; gr.update() changes properties.
        return (
            gr.update(
                choices=choices,
                value=None,
                label=f"Select Dress ({len(choices)} {dress_type} dresses)",
            ),
            f"Salesforce: {sf_state} | {len(choices)} dresses loaded",
        )

    # Both refresh buttons and a dress-type change all reload the dress list.
    for trigger in (refresh_btn.click, dress_type.change, refresh_manual.click):
        trigger(
            fn=update_dress_dropdown,
            inputs=dress_type,
            outputs=[dress_dropdown, connection_status],
        )

    try_btn.click(
        fn=process_try_on,
        inputs=[input_image, dress_dropdown, dress_type],
        outputs=[output_image, status],
    )

    # Populate the dress list when the page first loads.
    demo.load(
        fn=update_dress_dropdown,
        inputs=dress_type,
        outputs=[dress_dropdown, connection_status],
    )
|
|
|
|
|
if __name__ == "__main__":
    # Listen on every interface (not only localhost), e.g. so the app is
    # reachable from outside a container, on Gradio's default port.
    demo.launch(server_port=7860, server_name="0.0.0.0")