ritz26 commited on
Commit
9807323
·
1 Parent(s): 487cad7

Upload a virtual try on

Browse files
.gitignore ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ # Ignore user uploads and generated output
2
+ static/uploads/*
3
+ static/outputs/*
4
+
5
+ # But keep the directories themselves
6
+ !static/uploads/.gitkeep
7
+ !static/outputs/.gitkeep
Dockerfile ADDED
@@ -0,0 +1,31 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ FROM python:3.10.12-slim
2
+
3
+ WORKDIR /app
4
+
5
+ # Set cache directories
6
+ ENV HF_HOME=/app/.cache
7
+ ENV MPLCONFIGDIR=/app/.cache
8
+
9
+ # Install dependencies
10
+ RUN apt-get update && apt-get install -y libgl1-mesa-glx libglib2.0-0 && apt-get clean
11
+
12
+ # Create cache directory
13
+ RUN mkdir -p /app/.cache && chmod -R 777 /app/.cache
14
+
15
+ # Install Python dependencies
16
+ COPY requirements.txt .
17
+ RUN pip install --no-cache-dir -r requirements.txt
18
+
19
+ # Pre-download model
20
+ RUN python -c "from transformers import SamModel, SamProcessor; \
21
+ SamModel.from_pretrained('Zigeng/SlimSAM-uniform-50', cache_dir='/app/.cache'); \
22
+ SamProcessor.from_pretrained('Zigeng/SlimSAM-uniform-50', cache_dir='/app/.cache')"
23
+
24
+ # Copy app code
25
+ COPY . .
26
+
27
+ # Set port
28
+ ENV PORT=7860
29
+
30
+ # Run with Gunicorn
31
+ CMD ["gunicorn", "app:app", "--bind", "0.0.0.0:7860", "--workers", "2"]
app.py ADDED
@@ -0,0 +1,117 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from flask import Flask, render_template, request, send_from_directory
2
+ from PIL import Image
3
+ import os, torch, cv2, mediapipe as mp
4
+ from transformers import SamModel, SamProcessor, logging as hf_logging
5
+ from torchvision import transforms
6
+ from diffusers.utils import load_image
7
+ from flask_cors import CORS
8
+
9
+ app= Flask(__name__)
10
+ CORS(app)
11
+
12
+ # Enable Hugging Face detailed logs (shows model download progress)
13
+ hf_logging.set_verbosity_info()
14
+
15
+
16
+ UPLOAD_FOLDER = '/tmp/uploads'
17
+ OUTPUT_FOLDER = '/tmp/outputs'
18
+
19
+ if not os.path.exists(UPLOAD_FOLDER):
20
+ print(f"[WARN] {UPLOAD_FOLDER} does not exist. Creating...")
21
+ os.makedirs(UPLOAD_FOLDER, exist_ok=True)
22
+
23
+ if not os.path.exists(OUTPUT_FOLDER):
24
+ print(f"[WARN] {OUTPUT_FOLDER} does not exist. Creating...")
25
+ os.makedirs(OUTPUT_FOLDER, exist_ok=True)
26
+
27
+
28
+ # Lazy-load model
29
+ model, processor = None, None
30
+
31
+ def load_model():
32
+ global model, processor
33
+ if model is None or processor is None:
34
+ print("[INFO] Loading SAM model and processor...")
35
+ model = SamModel.from_pretrained("Zigeng/SlimSAM-uniform-50", cache_dir="/app/.cache")
36
+ processor = SamProcessor.from_pretrained("Zigeng/SlimSAM-uniform-50", cache_dir="/app/.cache")
37
+ print("[INFO] Model and processor loaded successfully!")
38
+
39
+ @app.before_request
40
+ def log_request_info():
41
+ print(f"[INFO] Incoming request: {request.method} {request.path}")
42
+
43
+ @app.route('/health')
44
+ def health():
45
+ return "OK", 200
46
+
47
+ # Route to serve outputs dynamically
48
+ @app.route('/outputs/<filename>')
49
+ def serve_output(filename):
50
+ return send_from_directory(OUTPUT_FOLDER, filename)
51
+
52
+ @app.route('/', methods=['GET', 'POST'])
53
+ def index():
54
+ print(f"[INFO] Handling {request.method} on /")
55
+ if request.method == 'POST':
56
+ try:
57
+ load_model()
58
+
59
+ # Save uploaded images
60
+ person_file = request.files['person_image']
61
+ tshirt_file = request.files['tshirt_image']
62
+ person_path = os.path.join(UPLOAD_FOLDER, 'person.jpg')
63
+ tshirt_path = os.path.join(UPLOAD_FOLDER, 'tshirt.png')
64
+ person_file.save(person_path)
65
+ tshirt_file.save(tshirt_path)
66
+ print(f"[INFO] Saved files to {UPLOAD_FOLDER}")
67
+
68
+ # Pose detection
69
+ mp_pose = mp.solutions.pose
70
+ pose = mp_pose.Pose()
71
+ image = cv2.imread(person_path)
72
+ if image is None:
73
+ return "No image detected."
74
+ image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
75
+ results = pose.process(image_rgb)
76
+ if not results.pose_landmarks:
77
+ return "No pose detected."
78
+ height, width, _ = image.shape
79
+ landmarks = results.pose_landmarks.landmark
80
+ left_shoulder = (int(landmarks[11].x * width), int(landmarks[11].y * height))
81
+ right_shoulder = (int(landmarks[12].x * width), int(landmarks[12].y * height))
82
+ print(f"[INFO] Shoulder coordinates: {left_shoulder}, {right_shoulder}")
83
+
84
+ # SAM model inference
85
+ img = load_image(person_path)
86
+ new_tshirt = load_image(tshirt_path)
87
+ input_points = [[[left_shoulder[0], left_shoulder[1]], [right_shoulder[0], right_shoulder[1]]]]
88
+ inputs = processor(img, input_points=input_points, return_tensors="pt")
89
+ outputs = model(**inputs)
90
+ masks = processor.image_processor.post_process_masks(
91
+ outputs.pred_masks.cpu(),
92
+ inputs["original_sizes"].cpu(),
93
+ inputs["reshaped_input_sizes"].cpu()
94
+ )
95
+ mask_tensor = masks[0][0][2].to(dtype=torch.uint8)
96
+ mask = transforms.ToPILImage()(mask_tensor * 255)
97
+
98
+ # Combine images
99
+ new_tshirt = new_tshirt.resize(img.size, Image.LANCZOS)
100
+ img_with_new_tshirt = Image.composite(new_tshirt, img, mask)
101
+ result_path = os.path.join(OUTPUT_FOLDER, 'result.jpg')
102
+ img_with_new_tshirt.save(result_path)
103
+ print(f"[INFO] Result saved to {result_path}")
104
+
105
+ # Serve via dynamic route
106
+ return render_template('index.html', result_img='/outputs/result.jpg')
107
+
108
+ except Exception as e:
109
+ print(f"[ERROR] {e}")
110
+ return f"Error: {e}"
111
+
112
+ return render_template('index.html')
113
+
114
+ if __name__ == '__main__':
115
+
116
+ print("[INFO] Starting Flask server...")
117
+ app.run(debug=true, host='0.0.0.0')
requirements.txt ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ Flask
2
+ gunicorn
3
+ Pillow
4
+ opencv-python
5
+ torch
6
+ torchvision
7
+ mediapipe
8
+ transformers
9
+ diffusers
10
+ safetensors
11
+ flask-cors
runtime.txt ADDED
@@ -0,0 +1 @@
 
 
1
+ python-3.10.12
static/outputs/.gitkeep ADDED
File without changes
static/uploads/.gitkeep ADDED
File without changes
templates/index.html ADDED
@@ -0,0 +1,76 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ <!DOCTYPE html>
2
+ <html lang="en">
3
+
4
+ <head>
5
+ <meta charset="UTF-8">
6
+ <title>Virtual Fashion Try-On</title>
7
+ <script src="https://cdn.tailwindcss.com"></script>
8
+ </head>
9
+
10
+ <body class="bg-gray-900 text-white min-h-screen flex flex-col items-center py-10">
11
+
12
+ <h1 class="text-4xl font-bold text-blue-400 mb-10">Virtual Fashion Try-On</h1>
13
+
14
+ <form action="/" method="post" enctype="multipart/form-data" class="w-full max-w-4xl bg-gray-800 rounded-2xl shadow-lg p-8 space-y-6">
15
+
16
+ <div class="grid grid-cols-1 md:grid-cols-2 gap-8">
17
+ <!-- Person Image Upload -->
18
+ <div>
19
+ <h2 class="text-lg font-semibold mb-2">Upload your photo</h2>
20
+ <label for="person_image" class="flex flex-col items-center justify-center border-2 border-dashed border-gray-600 rounded-xl p-6 hover:bg-gray-700 cursor-pointer relative">
21
+ <svg xmlns="http://www.w3.org/2000/svg" class="h-10 w-10 text-gray-400 mb-2" fill="none" viewBox="0 0 24 24" stroke="currentColor">
22
+ <path stroke-linecap="round" stroke-linejoin="round" stroke-width="2" d="M7 16v4m0 0h10m-10 0v-4m0 0h10m-10 0V5m0 0h10m-10 0H5m14 0h-2" />
23
+ </svg>
24
+ <p class="text-gray-400">Drag & drop or click to upload</p>
25
+ <input id="person_image" type="file" name="person_image" class="hidden" required onchange="showFileName('person_image', 'person_filename')">
26
+ </label>
27
+ <p id="person_filename" class="text-green-400 text-sm mt-2 text-center"></p>
28
+ </div>
29
+
30
+ <!-- T-Shirt Image Upload -->
31
+ <div>
32
+ <h2 class="text-lg font-semibold mb-2">Upload garment image</h2>
33
+ <label for="tshirt_image" class="flex flex-col items-center justify-center border-2 border-dashed border-gray-600 rounded-xl p-6 hover:bg-gray-700 cursor-pointer relative">
34
+ <svg xmlns="http://www.w3.org/2000/svg" class="h-10 w-10 text-gray-400 mb-2" fill="none" viewBox="0 0 24 24" stroke="currentColor">
35
+ <path stroke-linecap="round" stroke-linejoin="round" stroke-width="2" d="M7 16v4m0 0h10m-10 0v-4m0 0h10m-10 0V5m0 0h10m-10 0H5m14 0h-2" />
36
+ </svg>
37
+ <p class="text-gray-400">Drag & drop or click to upload</p>
38
+ <input id="tshirt_image" type="file" name="tshirt_image" class="hidden" required onchange="showFileName('tshirt_image', 'tshirt_filename')">
39
+ </label>
40
+ <p id="tshirt_filename" class="text-green-400 text-sm mt-2 text-center"></p>
41
+ </div>
42
+ </div>
43
+
44
+ <!-- Submit Button -->
45
+ <div class="flex justify-center">
46
+ <button type="submit" class="bg-pink-500 hover:bg-pink-600 text-white font-semibold py-3 px-8 rounded-xl shadow-md transition">
47
+ 🚀 Perform Virtual Try-On
48
+ </button>
49
+ </div>
50
+
51
+ </form>
52
+
53
+ {% if result_img %}
54
+ <div class="w-full max-w-4xl mt-10 bg-gray-800 rounded-2xl shadow-lg p-8">
55
+ <h2 class="text-2xl font-bold mb-6 text-center">🎉 Your Virtual Try-On Result</h2>
56
+ <div class="flex justify-center">
57
+ <img src="{{ result_img }}" alt="Result Image">
58
+ </div>
59
+ </div>
60
+ {% endif %}
61
+
62
+ <script>
63
+ function showFileName(inputId, filenameId) {
64
+ const input = document.getElementById(inputId);
65
+ const filename = document.getElementById(filenameId);
66
+ if (input.files.length > 0) {
67
+ filename.textContent = "✔️ " + input.files[0].name + " uploaded";
68
+ } else {
69
+ filename.textContent = "";
70
+ }
71
+ }
72
+ </script>
73
+
74
+ </body>
75
+
76
+ </html>