from flask import Flask, request, jsonify
from flask_cors import CORS
import threading
import time
import traceback
import cv2, mediapipe as mp, numpy as np, base64, io
from PIL import Image
# Flask application with CORS enabled for all routes.
app = Flask(__name__)
CORS(app)

# MediaPipe handles. `mp_pose` is usable immediately; `pose` (the Pose model)
# and `mp_drawing` (drawing utilities) start as None and are assigned later by
# init_mediapipe(), which runs on a background thread.
mp_pose = mp.solutions.pose
pose = None
mp_drawing = None

print("🚀 Flask app starting...")
@app.route("/", methods=["GET"])
def home():
    """Landing page confirming that the API process is up.

    Returns:
        A short status message with HTTP 200.
    """
    # Keep the literal on one line: a string broken across physical lines is
    # an unterminated literal and prevents the whole module from importing.
    return "✅ Virtual Try-On API is Running (Health OK)", 200
@app.route("/health", methods=["GET"])
def health():
    """Health probe: always responds 200 with ``{"status": "ok"}``."""
    body = {"status": "ok"}
    return jsonify(body), 200
def init_mediapipe():
    """Load MediaPipe Pose in a background thread.

    Assigns the module-level ``mp_drawing`` and ``pose`` globals. ``pose`` is
    assigned last, only once everything else is ready. Request handlers use
    ``pose is not None`` as the readiness signal, so they can never see a
    half-initialised state.

    On failure the error and traceback are printed and ``pose`` stays ``None``.
    Nothing is raised, because there is no caller on the loader thread to
    handle an exception.
    """
    global pose, mp_drawing
    try:
        print("🧠 Initializing MediaPipe Pose (CPU)...")
        start = time.perf_counter()
        # Drawing utilities first: they must be available before `pose`
        # signals readiness.
        mp_drawing = mp.solutions.drawing_utils
        # NOTE(review): static_image_mode=False enables tracking across calls,
        # and all clients share this one instance; confirm that is intended
        # for independent HTTP uploads.
        loaded_pose = mp_pose.Pose(static_image_mode=False, min_detection_confidence=0.5)
        # Publish only the fully constructed object.
        pose = loaded_pose
        print("✅ MediaPipe initialized in", round(time.perf_counter() - start, 2), "seconds")
    except Exception as e:
        print("❌ MediaPipe init error:", e)
        traceback.print_exc()
# Load MediaPipe off the main thread so the server can start answering
# requests (e.g. /health) right away. daemon=True means this thread does not
# keep the interpreter alive on exit.
_mediapipe_loader = threading.Thread(target=init_mediapipe, daemon=True)
_mediapipe_loader.start()
def overlay_dress(frame, dress, landmarks, *, width_scale=1.8, height_scale=1.2, top_offset=30):
    """Blend ``dress`` over the torso region outlined by pose ``landmarks``.

    Target rectangle geometry:

    - Width: ``width_scale`` times the shoulder distance.
    - Height: ``height_scale`` times the span from the higher shoulder to the
      lower hip.
    - Position: centred horizontally between the shoulders, starting
      ``top_offset`` pixels above the higher shoulder.

    Any part of the rectangle outside the frame is cropped away, not squeezed
    in, so the dress keeps its size and position near the image borders.

    Args:
        frame: 3-channel BGR image (as built by ``/tryon``). Blended in place.
        dress: Image as returned by ``cv2.imdecode(..., cv2.IMREAD_UNCHANGED)``:
            grayscale, BGR or BGRA, 8- or 16-bit. A 4th channel is used as
            per-pixel alpha; without one, the dress is mixed 50/50 with the frame.
        landmarks: MediaPipe pose landmark list with normalised ``x``/``y``, or None.
        width_scale: Dress width as a multiple of the shoulder distance.
        height_scale: Dress height as a multiple of the shoulder-to-hip span.
        top_offset: Pixels above the higher shoulder where the dress starts.
            The defaults of these three reproduce the original hard-coded
            geometry.

    Returns:
        ``frame`` (the same array). It is left untouched when the landmarks or
        dress are missing, when the rectangle is degenerate or entirely
        off-screen, or when blending fails. Failures are printed, not raised.
    """
    if landmarks is None or dress is None:
        return frame

    h, w = frame.shape[:2]

    def to_pixel(lm):
        # Landmark x/y are normalised by image width/height.
        return int(lm.x * w), int(lm.y * h)

    def as_uint8_color(img):
        # IMREAD_UNCHANGED keeps 16-bit depth and single-channel grayscale;
        # the blending code below needs 8-bit BGR or BGRA.
        if img.dtype == np.uint16:
            img = (img >> 8).astype(np.uint8)
        if img.ndim == 2:
            img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
        return img

    try:
        left_shoulder = to_pixel(landmarks[mp_pose.PoseLandmark.LEFT_SHOULDER.value])
        right_shoulder = to_pixel(landmarks[mp_pose.PoseLandmark.RIGHT_SHOULDER.value])
        left_hip = to_pixel(landmarks[mp_pose.PoseLandmark.LEFT_HIP.value])
        right_hip = to_pixel(landmarks[mp_pose.PoseLandmark.RIGHT_HIP.value])

        # Dress size in pixels.
        shoulder_span = np.linalg.norm(np.subtract(left_shoulder, right_shoulder))
        dress_width = int(shoulder_span * width_scale)
        top_y = min(left_shoulder[1], right_shoulder[1])
        bottom_y = max(left_hip[1], right_hip[1])
        dress_height = int((bottom_y - top_y) * height_scale)
        if dress_width <= 0 or dress_height <= 0:
            print("⚠️ Skipping overlay — invalid dress size.")
            return frame

        # Full (unclamped) placement of the dress in frame coordinates.
        center_x = (left_shoulder[0] + right_shoulder[0]) // 2
        left = center_x - dress_width // 2
        top = top_y - top_offset

        # Part of that placement that actually lies inside the frame.
        x1, y1 = max(left, 0), max(top, 0)
        x2, y2 = min(left + dress_width, w), min(top + dress_height, h)
        if x2 <= x1 or y2 <= y1:
            print("⚠️ Skipping overlay — invalid bounding box.")
            return frame

        dress_full = cv2.resize(
            as_uint8_color(dress), (dress_width, dress_height), interpolation=cv2.INTER_AREA
        )
        # Crop the dress to the visible part so it lines up with frame[y1:y2, x1:x2].
        patch = dress_full[y1 - top:y2 - top, x1 - left:x2 - left]
        roi = frame[y1:y2, x1:x2]  # a view: writing to it writes into `frame`

        if patch.shape[2] == 4:
            # Per-pixel alpha blend using the dress's transparency channel.
            alpha = patch[:, :, 3:4] / 255.0
            roi[:] = alpha * patch[:, :, :3] + (1.0 - alpha) * roi
        else:
            # No alpha channel: mix dress and frame 50/50.
            roi[:] = cv2.addWeighted(roi, 0.5, patch, 0.5, 0)
    except Exception as e:
        # Best effort: a failed overlay should not fail the whole request.
        print("⚠️ overlay_dress error:", e)
        traceback.print_exc()
    return frame
# The shared MediaPipe Pose instance keeps internal graph state and should not
# be driven from several threads at once. Flask's app.run() serves requests on
# multiple threads by default, so calls to pose.process are serialized.
_pose_lock = threading.Lock()


def _decode_data_url(value):
    """Return the bytes of a base64 payload, with or without a data-URL prefix.

    Accepts ``"data:image/png;base64,AAAA"`` as well as a bare ``"AAAA"``.

    Raises:
        ValueError: If ``value`` is not a string or is not valid base64
            (``binascii.Error`` is a ``ValueError`` subclass).
    """
    if not isinstance(value, str):
        raise ValueError("expected a base64-encoded string")
    # Base64 never contains commas, so the last comma-separated piece is the payload.
    return base64.b64decode(value.split(",")[-1])


@app.route("/tryon", methods=["POST"])
def tryon():
    """Overlay a dress on a person photo and return the composited JPEG.

    Expects a JSON body ``{"image": ..., "dress": ...}`` in which both values
    are base64 strings, optionally in data-URL form.

    Responses:
        - 200: ``{"image": "data:image/jpeg;base64,..."}``. Pose landmarks are
          also drawn when a person is detected.
        - 400: missing, malformed or undecodable input.
        - 503: the MediaPipe model has not finished loading.
        - 500: any other unexpected failure (message in ``"error"``).
    """
    try:
        # `pose` is assigned last by the loader, but check both handles so a
        # partially initialised state is never used.
        if pose is None or mp_drawing is None:
            return jsonify({"error": "Model still loading, please retry in 1–2 minutes"}), 503

        # silent=True: a non-JSON body yields None instead of raising.
        payload = request.get_json(silent=True)
        if not isinstance(payload, dict):
            payload = {}
        data = payload.get("image")
        dress_data = payload.get("dress")
        if not data or not dress_data:
            return jsonify({"error": "Missing image or dress data"}), 400

        try:
            img_bytes = _decode_data_url(data)
            dress_bytes = _decode_data_url(dress_data)
            # PIL raises an OSError subclass (UnidentifiedImageError) on bad data.
            rgb = np.array(Image.open(io.BytesIO(img_bytes)).convert("RGB"))
        except (ValueError, OSError) as e:
            return jsonify({"error": f"Invalid image or dress data: {e}"}), 400

        # imdecode rejects an empty buffer and returns None for undecodable bytes.
        dress_img = (
            cv2.imdecode(np.frombuffer(dress_bytes, np.uint8), cv2.IMREAD_UNCHANGED)
            if dress_bytes
            else None
        )
        if dress_img is None:
            return jsonify({"error": "Invalid dress image data"}), 400

        frame = cv2.cvtColor(rgb, cv2.COLOR_RGB2BGR)
        # MediaPipe consumes RGB, which we already have from PIL.
        with _pose_lock:
            results = pose.process(rgb)

        landmarks = results.pose_landmarks.landmark if results.pose_landmarks else None
        frame = overlay_dress(frame, dress_img, landmarks)
        if results.pose_landmarks:
            mp_drawing.draw_landmarks(frame, results.pose_landmarks, mp_pose.POSE_CONNECTIONS)

        ok, buffer = cv2.imencode(".jpg", frame)
        if not ok:
            raise RuntimeError("JPEG encoding failed")
        img_base64 = "data:image/jpeg;base64," + base64.b64encode(buffer).decode()
        return jsonify({"image": img_base64})
    except Exception as e:
        print("❌ /tryon error:", e)
        traceback.print_exc()
        return jsonify({"error": str(e)}), 500
if __name__ == "__main__":
    import os

    # Listen on all interfaces. The port comes from $PORT, falling back to
    # 7860, the port Hugging Face Spaces uses by default.
    app.run(
        host="0.0.0.0",
        port=int(os.environ.get("PORT", 7860)),
    )