File size: 5,006 Bytes
0ac22d5
c8438bd
29d7659
 
c8438bd
29d7659
 
0ac22d5
 
29d7659
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0ac22d5
29d7659
 
0ac22d5
c8438bd
0ac22d5
3006ad1
 
c8438bd
3006ad1
5033bed
3006ad1
 
 
 
0ac22d5
 
 
 
5033bed
3006ad1
c8438bd
29d7659
 
 
5033bed
3006ad1
c8438bd
 
29d7659
0ac22d5
 
5033bed
3006ad1
 
 
 
 
 
 
 
 
 
c8438bd
3006ad1
 
29d7659
5033bed
0ac22d5
 
c8438bd
3006ad1
 
c8438bd
3006ad1
 
 
 
 
 
 
 
 
 
0ac22d5
 
c8438bd
3006ad1
0ac22d5
 
bb5b4b9
29d7659
 
 
 
 
 
5033bed
 
 
bb5b4b9
 
 
 
01db935
bb5b4b9
 
01db935
bb5b4b9
 
01db935
bb5b4b9
 
 
01db935
bb5b4b9
 
01db935
bb5b4b9
 
29d7659
c8438bd
5033bed
0ac22d5
c8438bd
5033bed
5ba1447
6c1dc97
 
 
5ba1447
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
from flask import Flask, request, jsonify
from flask_cors import CORS
import threading
import time
import traceback
import cv2, mediapipe as mp, numpy as np, base64, io
from PIL import Image

# Flask application; CORS is enabled for every route.
app = Flask(__name__)
CORS(app)

# MediaPipe handles. `pose` and `mp_drawing` stay None until the
# background initializer (init_mediapipe) assigns them.
mp_pose = mp.solutions.pose
pose = mp_drawing = None

print("πŸš€ Flask app starting...")

@app.route("/", methods=["GET"])
def home():
    """Serve a minimal HTML banner confirming the API process is up."""
    banner = "<h1>βœ… Virtual Try-On API is Running (Health OK)</h1>"
    status_code = 200
    return banner, status_code

@app.route("/health", methods=["GET"])
def health():
    """Liveness probe: always answers ``{"status": "ok"}`` with HTTP 200."""
    return jsonify(status="ok"), 200


def init_mediapipe():
    """Load MediaPipe Pose in a background thread.

    Builds the Pose graph and the drawing utilities, then stores them in the
    module globals ``mp_drawing`` and ``pose``. Request handlers treat
    ``pose is not None`` as "model ready", so ``pose`` is assigned last. That
    way everything a handler uses is already in place once it becomes
    visible. If construction fails, the error and traceback are printed and
    neither global is modified.
    """
    global pose, mp_drawing
    try:
        print("🧠 Initializing MediaPipe Pose (CPU)...")
        started = time.perf_counter()  # monotonic clock for measuring durations
        drawing_utils = mp.solutions.drawing_utils
        pose_model = mp_pose.Pose(static_image_mode=False, min_detection_confidence=0.5)
        # Publish in dependency order: readiness flag (`pose`) goes last.
        mp_drawing = drawing_utils
        pose = pose_model
        print("βœ… MediaPipe initialized in", round(time.perf_counter() - started, 2), "seconds")
    except Exception as e:
        # Background-thread boundary: nothing else would report this failure.
        print("❌ MediaPipe init error:", e)
        traceback.print_exc()


# Load the MediaPipe model off the main thread so startup is not blocked
# (/tryon answers 503 until it is ready). daemon=True keeps this thread
# from holding the interpreter open at shutdown.
threading.Thread(
    target=init_mediapipe,
    daemon=True,
).start()


def overlay_dress(frame, dress, landmarks):
    """Blend a garment image onto ``frame`` over the detected torso.

    The target rectangle comes from four MediaPipe Pose landmarks (both
    shoulders and both hips):
    - width is 1.8x the shoulder distance, centred between the shoulders;
    - height is 1.2x the shoulder-to-hip vertical span;
    - the top edge sits 30 px above the higher shoulder.
    The rectangle is clamped to the frame.

    Args:
        frame: BGR image. It is modified in place and also returned.
        dress: Garment image as produced by ``cv2.imdecode(...,
            cv2.IMREAD_UNCHANGED)``. A single-channel image is expanded to
            BGR. A 4-channel image is alpha-blended using its last channel.
            Any other image is mixed 50/50 with the pixels underneath.
        landmarks: Indexable pose landmarks exposing ``x``/``y`` attributes
            (scaled by frame width/height below), or ``None`` when no person
            was detected.

    Returns:
        ``frame`` with the garment drawn on it. It is returned unchanged when
        there is nothing to draw or the computed region is degenerate. Errors
        are logged and swallowed so a failed overlay never fails the caller.
    """
    if landmarks is None or dress is None:
        return frame

    h, w = frame.shape[:2]

    def to_pixel(lm):
        # Scale landmark coordinates to integer pixel positions.
        return int(lm.x * w), int(lm.y * h)

    try:
        left_shoulder = to_pixel(landmarks[mp_pose.PoseLandmark.LEFT_SHOULDER.value])
        right_shoulder = to_pixel(landmarks[mp_pose.PoseLandmark.RIGHT_SHOULDER.value])
        left_hip = to_pixel(landmarks[mp_pose.PoseLandmark.LEFT_HIP.value])
        right_hip = to_pixel(landmarks[mp_pose.PoseLandmark.RIGHT_HIP.value])

        # Size of the garment region, derived from torso proportions.
        dress_width = int(np.linalg.norm(np.subtract(left_shoulder, right_shoulder)) * 1.8)
        top_y = min(left_shoulder[1], right_shoulder[1])
        bottom_y = max(left_hip[1], right_hip[1])
        dress_height = int((bottom_y - top_y) * 1.2)

        # Checked before the bounding box: with a non-positive size the box
        # check below would always fire first and this message would never print.
        if dress_width <= 0 or dress_height <= 0:
            print("⚠️ Skipping overlay β€” invalid dress size.")
            return frame

        # Place the region and clamp it to the frame bounds.
        center_x = (left_shoulder[0] + right_shoulder[0]) // 2
        x1 = max(center_x - dress_width // 2, 0)
        y1 = max(top_y - 30, 0)
        x2 = min(x1 + dress_width, w)
        y2 = min(y1 + dress_height, h)

        # Region can collapse when the person is (partly) outside the frame.
        if x2 <= x1 or y2 <= y1:
            print("⚠️ Skipping overlay β€” invalid bounding box.")
            return frame

        garment = dress
        if garment.ndim == 2:
            # Grayscale garment: expand to 3 channels so it matches the BGR frame.
            garment = cv2.cvtColor(garment, cv2.COLOR_GRAY2BGR)
        garment = cv2.resize(garment, (x2 - x1, y2 - y1), interpolation=cv2.INTER_AREA)

        roi = frame[y1:y2, x1:x2]  # view: writes below land in `frame`
        if garment.shape[2] == 4:
            # Per-pixel alpha blend. `3:4` keeps a trailing axis so alpha
            # broadcasts across the three colour channels.
            alpha = garment[:, :, 3:4] / 255.0
            blended = alpha * garment[:, :, :3] + (1.0 - alpha) * roi[:, :, :3]
            roi[:, :, :3] = blended.astype(roi.dtype)
        else:
            # No alpha channel: mix garment and background 50/50.
            roi[:] = cv2.addWeighted(roi, 0.5, garment, 0.5, 0)

    except Exception as e:
        # Best-effort overlay: log and fall through to return the frame as-is.
        print("⚠️ overlay_dress error:", e)
        traceback.print_exc()

    return frame



# One MediaPipe graph (`pose`) is shared by every request, so calls into it
# are serialized.
# NOTE(review): Flask's dev server (`app.run`) handles requests on multiple
# threads by default, and the graph is presumably not safe for concurrent
# `process()` calls; confirm against the MediaPipe version in use.
# Also note `pose` runs with static_image_mode=False (tracking mode), so
# frames from different clients feed the same tracker.
_pose_lock = threading.Lock()


def _b64_payload(value):
    """Return the base64 part of a data URL such as ``data:image/png;base64,XXXX``.

    Everything up to and including the first comma is dropped. A bare base64
    string (no comma) is returned unchanged.
    """
    return value.split(",", 1)[-1]


@app.route("/tryon", methods=["POST"])
def tryon():
    """Overlay a garment onto a photo of a person and return the result.

    Expects a JSON body ``{"image": ..., "dress": ...}``. Both values are
    base64-encoded images, optionally wrapped in a ``data:...;base64,`` URL.

    Responses:
        200 ``{"image": "data:image/jpeg;base64,..."}``: the photo with the
            garment blended over the torso and pose landmarks drawn on top.
        400 ``{"error": ...}``: missing/invalid fields or an undecodable dress.
        503 ``{"error": ...}``: MediaPipe has not finished loading yet.
        500 ``{"error": ...}``: any other failure (message is ``str(e)``).
    """
    try:
        # `pose` is the readiness flag. `mp_drawing` is checked too because it
        # is used below and is assigned separately by the loader.
        if pose is None or mp_drawing is None:
            return jsonify({"error": "Model still loading, please retry in 1–2 minutes"}), 503

        # silent=True: a non-JSON or malformed body yields None instead of
        # raising, so it is reported as a client error rather than a 500.
        payload = request.get_json(silent=True)
        if not isinstance(payload, dict):
            payload = {}
        data = payload.get("image")
        dress_data = payload.get("dress")

        if not (isinstance(data, str) and data and isinstance(dress_data, str) and dress_data):
            return jsonify({"error": "Missing image or dress data"}), 400

        img_bytes = base64.b64decode(_b64_payload(data))
        img = Image.open(io.BytesIO(img_bytes)).convert("RGB")
        frame = cv2.cvtColor(np.array(img), cv2.COLOR_RGB2BGR)

        # IMREAD_UNCHANGED preserves an alpha channel, which overlay_dress uses.
        dress_bytes = base64.b64decode(_b64_payload(dress_data))
        dress_img = cv2.imdecode(np.frombuffer(dress_bytes, np.uint8), cv2.IMREAD_UNCHANGED)
        if dress_img is None:
            # Previously ignored: the photo came back with no garment and no error.
            return jsonify({"error": "Could not decode dress image"}), 400

        rgb_frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
        with _pose_lock:
            results = pose.process(rgb_frame)
        landmarks = results.pose_landmarks.landmark if results.pose_landmarks else None

        frame = overlay_dress(frame, dress_img, landmarks)
        if results.pose_landmarks:
            mp_drawing.draw_landmarks(frame, results.pose_landmarks, mp_pose.POSE_CONNECTIONS)

        ok, buffer = cv2.imencode(".jpg", frame)
        if not ok:
            return jsonify({"error": "Failed to encode result image"}), 500
        img_base64 = "data:image/jpeg;base64," + base64.b64encode(buffer).decode()

        return jsonify({"image": img_base64})
    except Exception as e:
        # Top-level request boundary: log the full traceback, report the message.
        print("❌ /tryon error:", e)
        traceback.print_exc()
        return jsonify({"error": str(e)}), 500


if __name__ == "__main__":
    import os

    # Listen on all interfaces. PORT overrides the default of 7860, which is
    # the port Hugging Face Spaces expects.
    listen_port = int(os.getenv("PORT", "7860"))
    app.run(host="0.0.0.0", port=listen_port)