# PrashanthB461 — Update app.py — 9429334 (verified)
import base64
import logging
import os
import time
from io import BytesIO

import cv2
import gradio as gr
import mediapipe as mp
import numpy as np
import requests
from PIL import Image
from simple_salesforce import Salesforce, SalesforceExpiredSession
# --- Logging ---------------------------------------------------------------
# Process-wide INFO logging plus a logger named after this module.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# --- MediaPipe pose estimator ----------------------------------------------
# A single estimator is built at import time and reused for every detection.
# static_image_mode=True: treat each input as an unrelated still photo;
# model_complexity=2: the most accurate (and slowest) of MediaPipe's pose models.
mp_pose = mp.solutions.pose
pose = mp_pose.Pose(
    model_complexity=2,
    static_image_mode=True,
)
# Salesforce configuration, read only from environment variables.
# Credentials must never be hard-coded here: anything committed to source
# control (or published in a public Space) has to be treated as leaked.
# NOTE(review): the password / security token that used to be hard-coded as
# fallbacks should be rotated in Salesforce.
# If a variable is unset its value is None; the Salesforce login is then
# expected to fail, which SalesforceConnector.connect() catches and logs, so
# the app starts in "Disconnected" mode instead of crashing.
SF_CONFIG = {
    "username": os.environ.get("SF_USERNAME"),
    "password": os.environ.get("SF_PASSWORD"),
    "security_token": os.environ.get("SF_SECURITY_TOKEN"),
    # "login" for production / developer orgs, "test" for sandboxes.
    "domain": os.environ.get("SF_DOMAIN", "login"),
}
class SalesforceConnector:
    """Owns the simple_salesforce session used to read ``Dress__c`` records.

    Connection errors are never raised to callers: they are logged and leave
    ``self.sf`` set to ``None`` ("disconnected").
    """

    # Only these literals are ever spliced into the SOQL WHERE clause, which is
    # what keeps the f-string query safe from SOQL injection.
    ALLOWED_DRESS_TYPES = ("Casual", "Formal")

    def __init__(self):
        self.sf = None  # Salesforce session object, or None while disconnected
        self.connect()

    def connect(self):
        """(Re)authenticate with the credentials in ``SF_CONFIG``.

        On success ``self.sf`` holds the new session; on any error the error is
        logged and ``self.sf`` is reset to ``None``.
        """
        try:
            self.sf = Salesforce(
                username=SF_CONFIG["username"],
                password=SF_CONFIG["password"],
                security_token=SF_CONFIG["security_token"],
                domain=SF_CONFIG["domain"],
            )
            logger.info("Successfully connected to Salesforce")
        except Exception as e:
            logger.error(f"Salesforce connection failed: {str(e)}")
            self.sf = None

    def is_connected(self):
        """Return True if a session object exists.

        This is a local check only: the session may still have expired on the
        server side, which ``query_dresses`` handles by logging in again once.
        """
        return self.sf is not None

    def _query_with_relogin(self, soql):
        """Run *soql*; if the session has expired, reconnect once and retry.

        Raises whatever ``self.sf.query`` raises (including the original
        ``SalesforceExpiredSession`` when the re-login itself fails).
        """
        try:
            return self.sf.query(soql)
        except SalesforceExpiredSession:
            logger.warning("Salesforce session expired; reconnecting and retrying once")
            self.connect()
            if not self.is_connected():
                raise
            return self.sf.query(soql)

    def query_dresses(self, dress_type=None):
        """Return up to 100 ``Dress__c`` records.

        Args:
            dress_type: "Casual" or "Formal" to filter on ``Dress_Type__c``;
                None, or any other value, returns dresses of every type.

        Returns:
            The ``records`` list of the SOQL result (fields Id, Name,
            Image_URL__c, Dress_Type__c), or ``[]`` when Salesforce is
            unreachable or the query fails. Always a list, on every path.
        """
        if not self.is_connected():
            self.connect()
            if not self.is_connected():
                logger.error("Failed to reconnect to Salesforce")
                return []

        query = "SELECT Id, Name, Image_URL__c, Dress_Type__c FROM Dress__c"
        if dress_type in self.ALLOWED_DRESS_TYPES:  # whitelist prevents SOQL injection
            query += f" WHERE Dress_Type__c = '{dress_type}'"
        elif dress_type is not None:
            logger.warning(f"Unsupported dress_type {dress_type!r}; querying all dresses")
        query += " LIMIT 100"

        try:
            result = self._query_with_relogin(query)
            records = result.get("records", [])
            logger.info(f"Fetched {len(records)} records for {dress_type or 'all'} dresses")
            return records
        except Exception as e:
            logger.error(f"Salesforce query failed: {str(e)}")
            return []


# Module-level connector shared by every request handler. Constructing it
# attempts the Salesforce login immediately, i.e. at import time.
sf_connector = SalesforceConnector()
def fetch_dresses(dress_type=None):
    """Query dress records from Salesforce and download their images.

    Reconnection is delegated to ``sf_connector.query_dresses``. Nothing is
    cached: every call re-queries Salesforce and re-downloads every image.

    Args:
        dress_type: dress type filter passed straight to
            ``sf_connector.query_dresses`` (None means all dresses).

    Returns:
        dict mapping Salesforce record Id to
        ``{"name", "image" (RGBA PIL image), "thumbnail" (PNG data URI), "type"}``.
        Records without an image URL, or whose image cannot be downloaded or
        decoded, are logged and skipped.
    """
    dresses = {}
    records = sf_connector.query_dresses(dress_type)

    # One HTTP session for the whole batch, so connections to the same image
    # host are reused instead of being re-established for every record.
    with requests.Session() as http:
        for item in records:
            name = item.get('Name')
            image_url = item.get('Image_URL__c')
            logger.info(f"Processing record: {name} with Image_URL__c: {image_url}")
            if not image_url:
                logger.warning(f"Skipping dress {name} due to missing Image_URL__c")
                continue
            try:
                img_response = http.get(image_url, timeout=10)
                img_response.raise_for_status()
                img = Image.open(BytesIO(img_response.content)).convert("RGBA")
                dresses[item['Id']] = {
                    'name': item['Name'],
                    'image': img,
                    'thumbnail': generate_thumbnail(img),
                    # SOQL results contain the key with a None value when the
                    # field is empty, so a .get() default would never apply.
                    'type': item.get('Dress_Type__c') or 'Unknown',
                }
            except Exception as e:
                logger.warning(f"Failed to load dress {name}: {str(e)}")

    logger.info(f"Loaded {len(dresses)} {dress_type or 'all'} dresses")
    return dresses
def generate_thumbnail(img, size=(150, 200)):
    """Return a PNG thumbnail of *img* encoded as a ``data:`` URI for the UI.

    The caller's image is left untouched: a copy is shrunk in place with
    ``thumbnail`` so that it fits inside ``size`` (width, height).
    """
    preview = img.copy()
    preview.thumbnail(size)
    with BytesIO() as png_buffer:
        preview.save(png_buffer, format="PNG")
        encoded = base64.b64encode(png_buffer.getvalue()).decode()
    return "data:image/png;base64," + encoded
# Body landmark detection and dress overlay
def get_body_landmarks(image):
    """Detect shoulder and hip landmarks with MediaPipe Pose.

    Args:
        image: the user photo, either a PIL image or an array whose channels
            are already in RGB(A) order (as produced by ``np.array(pil_image)``).

    Returns:
        A list of four ``(x, y)`` tuples in MediaPipe's normalised coordinates,
        ordered left shoulder, right shoulder, left hip, right hip. Returns
        ``[]`` if no pose was found or detection failed (the error is logged).
    """
    try:
        # MediaPipe expects RGB input. PIL images are already RGB, so no
        # BGR->RGB conversion must happen here: that would swap red and blue.
        # Normalising the mode also makes RGBA / greyscale / palette uploads work.
        if isinstance(image, Image.Image):
            image = image.convert("RGB")
        image_rgb = np.asarray(image)
        if image_rgb.ndim == 3 and image_rgb.shape[2] == 4:
            image_rgb = image_rgb[:, :, :3]  # drop the alpha channel of RGBA arrays
        image_rgb = np.ascontiguousarray(image_rgb)

        results = pose.process(image_rgb)
        if not results.pose_landmarks:
            return []

        landmark = results.pose_landmarks.landmark
        wanted = (
            mp_pose.PoseLandmark.LEFT_SHOULDER,
            mp_pose.PoseLandmark.RIGHT_SHOULDER,
            mp_pose.PoseLandmark.LEFT_HIP,
            mp_pose.PoseLandmark.RIGHT_HIP,
        )
        return [(landmark[idx].x, landmark[idx].y) for idx in wanted]
    except Exception as e:
        logger.error(f"Body landmark detection failed: {str(e)}")
        return []
def warp_and_overlay(user_img, dress_img, body_pts, width_ratio=0.8):
    """Alpha-blend a dress image onto the user photo, anchored at the shoulders.

    The dress is scaled uniformly so its width is ``width_ratio`` times the
    shoulder width. It is centred horizontally between the shoulders, with its
    top edge on the shoulder line. The position is clamped so the dress stays
    inside the photo when it fits. A dress larger than the photo is cropped at
    the photo border.

    Args:
        user_img: H x W x C array, expected uint8 RGB(A) (C >= 3). It is
            modified in place and also returned.
        dress_img: h x w, h x w x 3 or h x w x 4 array; a 4th channel is used
            as the alpha mask.
        body_pts: landmarks as returned by ``get_body_landmarks``; only the
            first two (left / right shoulder, normalised x, y) are used.
        width_ratio: dress width relative to shoulder width.

    Returns:
        ``user_img`` with the dress blended in. It is returned unchanged if
        there are no landmarks, the scaled dress would be empty, or blending
        fails (the error is logged).
    """
    try:
        if not body_pts:
            return user_img

        user_height, user_width = user_img.shape[:2]
        dress_height, dress_width = dress_img.shape[:2]
        (left_x, left_y), (right_x, right_y) = body_pts[0], body_pts[1]

        shoulder_width = abs(left_x - right_x) * user_width
        scale_factor = shoulder_width / dress_width * width_ratio
        new_width = int(dress_width * scale_factor)
        new_height = int(dress_height * scale_factor)
        if new_width < 1 or new_height < 1:
            # e.g. a side-on pose where both shoulders share the same x;
            # cv2.resize rejects an empty target size.
            logger.warning("Dress overlay skipped: scaled dress size is empty")
            return user_img

        resized_dress = cv2.resize(dress_img, (new_width, new_height), interpolation=cv2.INTER_AREA)

        # Top-left corner: centred between the shoulders, top on the shoulder
        # line, then clamped to stay inside the photo when the dress fits.
        x = int((left_x + right_x) * user_width / 2 - new_width / 2)
        y = int((left_y + right_y) * user_height / 2)
        x = max(0, min(x, user_width - new_width))
        y = max(0, min(y, user_height - new_height))

        # When the dress is bigger than the photo the clamp pins it to 0; crop
        # the overhang so the patch and the target region have the same shape.
        visible_w = min(new_width, user_width - x)
        visible_h = min(new_height, user_height - y)
        patch = resized_dress[:visible_h, :visible_w]
        if patch.ndim == 2:  # greyscale dress -> replicate to 3 channels
            patch = np.dstack([patch, patch, patch])

        if patch.shape[2] == 4:
            # Keep alpha as (h, w, 1) so it broadcasts over the colour channels.
            alpha = patch[:, :, 3:4].astype(np.float32) / 255.0
        else:
            alpha = np.ones((visible_h, visible_w, 1), dtype=np.float32)
        dress_rgb = patch[:, :, :3].astype(np.float32)

        region = user_img[y:y + visible_h, x:x + visible_w, :3].astype(np.float32)
        blended = region * (1.0 - alpha) + dress_rgb * alpha
        user_img[y:y + visible_h, x:x + visible_w, :3] = (
            np.clip(np.rint(blended), 0, 255).astype(user_img.dtype)
        )
        return user_img
    except Exception as e:
        logger.error(f"Dress overlay failed: {str(e)}")
        return user_img
def process_try_on(user_image, dress_id, dress_type):
    """Gradio callback behind the "Try On Dress" button.

    Args:
        user_image: the uploaded / webcam photo (PIL image), or None.
        dress_id: Salesforce Id chosen in the dress dropdown, or None.
        dress_type: dress type currently selected in the type dropdown.

    Returns:
        ``(image, message)`` for the output image and the status textbox. On
        every failure path the photo comes back unchanged (None if nothing was
        uploaded) together with a message. Exceptions are logged, not raised.
    """
    if user_image is None:
        return None, "Please upload an image"

    started_at = time.time()
    logger.info(f"Processing try-on with dress_id: {dress_id}, dress_type: {dress_type}")

    try:
        photo = np.array(user_image)

        landmarks = get_body_landmarks(user_image)
        if not landmarks:
            return user_image, "Stand facing camera with arms slightly away"

        # NOTE: fetch_dresses re-queries Salesforce and re-downloads every
        # dress image of this type on each click.
        catalogue = fetch_dresses(dress_type)
        if not catalogue:
            return user_image, f"No {dress_type} dresses available - check Salesforce connection"

        has_valid_selection = dress_id is not None and dress_id in catalogue
        if not has_valid_selection:
            logger.error(f"Invalid or None dress_id: {dress_id}, available: {list(catalogue)}")
            return user_image, "Please select a valid dress from the dropdown"

        garment = np.array(catalogue[dress_id]['image'])
        composited = warp_and_overlay(photo.copy(), garment, landmarks)
        output = Image.fromarray(composited)
        return output, f"Done in {time.time() - started_at:.2f}s"
    except Exception as err:
        logger.error(f"Try-on failed: {str(err)}")
        return user_image, f"Error: {str(err)}"
# Gradio Interface
with gr.Blocks(title="Virtual Try-On", css=".thumbnail { height: 100px !important }") as demo:
    gr.Markdown("# 👗 Virtual Try-On (Salesforce Connected)")
    with gr.Row():
        with gr.Column():
            input_image = gr.Image(sources=["upload", "webcam"], type="pil", label="Your Photo")
            dress_type = gr.Dropdown(
                label="Select Dress Type",
                choices=["Casual", "Formal"],
                value="Casual",
                interactive=True,
            )
            with gr.Row():
                refresh_btn = gr.Button("🔄 Refresh Dresses")
                try_btn = gr.Button("👗 Try On Dress", variant="primary")
            dress_dropdown = gr.Dropdown(label="Select Dress", interactive=True, value=None)
            connection_status = gr.Textbox(
                label="Salesforce Status",
                value="Connected" if sf_connector.is_connected() else "Disconnected",
            )
            refresh_manual = gr.Button("🔄 Manual Refresh")  # Added manual refresh button
        with gr.Column():
            output_image = gr.Image(label="Your Virtual Try-On", interactive=False)
            status = gr.Textbox(label="Status")

    def update_dress_dropdown(dress_type):
        """Reload dresses of *dress_type* and refresh the dress dropdown.

        Returns a property update for ``dress_dropdown`` (new choices, cleared
        selection, label with the count) and the text for ``connection_status``.
        """
        dresses = fetch_dresses(dress_type)
        # (label, value) pairs: the user sees the name, callbacks receive the Salesforce Id.
        choices = [(f"{d['name']} ({d['type']})", dress_id) for dress_id, d in dresses.items()]
        sf_status = "Connected" if sf_connector.is_connected() else "Disconnected"
        logger.info(f"Updating dropdown with {len(choices)} choices for {dress_type}, choices: {choices}")
        # Gradio applies a returned dict as a property update only when it is
        # built with gr.update(); a plain dict is passed through as the value.
        return (
            gr.update(
                choices=choices,
                value=None,
                label=f"Select Dress ({len(choices)} {dress_type} dresses)",
            ),
            f"Salesforce: {sf_status} | {len(choices)} dresses loaded",
        )

    # Both refresh buttons and a change of dress type reload the dress list.
    for trigger in (refresh_btn.click, dress_type.change, refresh_manual.click):
        trigger(
            fn=update_dress_dropdown,
            inputs=dress_type,
            outputs=[dress_dropdown, connection_status],
        )
    try_btn.click(
        fn=process_try_on,
        inputs=[input_image, dress_dropdown, dress_type],
        outputs=[output_image, status],
    )
    # Populate the dress dropdown when the page is first loaded.
    demo.load(
        fn=update_dress_dropdown,
        inputs=dress_type,
        outputs=[dress_dropdown, connection_status],
    )

if __name__ == "__main__":
    demo.launch(server_name="0.0.0.0", server_port=7860)