LiamKhoaLe commited on
Commit
1e21006
·
0 Parent(s):

Initial commit

Browse files
Files changed (8) hide show
  1. .gitignore +1 -0
  2. Dockerfile +23 -0
  3. README.md +11 -0
  4. app.py +136 -0
  5. requirements.txt +10 -0
  6. static/index.html +44 -0
  7. static/script.js +129 -0
  8. static/styles.css +119 -0
.gitignore ADDED
@@ -0,0 +1 @@
 
 
1
+ models
Dockerfile ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ FROM python:3.11-slim
2
+
3
+ # ── Create and switch to non‑root user ──
4
+ RUN useradd -m -u 1000 user
5
+ USER user
6
+
7
+ # ── Set environment and working directory ──
8
+ ENV HOME=/home/user
9
+ WORKDIR $HOME/app
10
+
11
+ # ── Upgrade pip and install dependencies ──
12
+ COPY --chown=user requirements.txt .
13
+ RUN pip install --upgrade pip && \
14
+ pip install --no-cache-dir -r requirements.txt
15
+
16
+ # ── Copy application source ──
17
+ COPY --chown=user . .
18
+
19
+ # ── Create cache / log folders with correct permissions ──
20
+ RUN mkdir -p $HOME/app/cache $HOME/app/logs
21
+
22
+ EXPOSE 7860
23
+ CMD ["python", "-m", "uvicorn", "app:app", "--host", "0.0.0.0", "--port", "7860"]
README.md ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ title: Similarity Checker
3
+ emoji: 😃
4
+ colorFrom: red
5
+ colorTo: green
6
+ sdk: docker
7
+ sdk_version: latest
8
+ pinned: false
9
+ license: apache-2.0
10
+ short_description: Similarity comparison between facial images using self-trained TripletLoss model (CNN) and Cosine metric
11
+ ---
app.py ADDED
@@ -0,0 +1,136 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from pathlib import Path
3
+ from tempfile import TemporaryDirectory
4
+
5
+ import cv2
6
+ import numpy as np
7
+ import torch
8
+ import torch.nn as nn
9
+ from PIL import Image
10
+ from fastapi import FastAPI, File, UploadFile, HTTPException
11
+ from fastapi.middleware.cors import CORSMiddleware
12
+ from fastapi.responses import JSONResponse, FileResponse
13
+ from fastapi.staticfiles import StaticFiles
14
+ from torchvision import transforms
15
+
16
+ # ─────────────────────────────────────────────────────────────────────────────
17
+ # Model Definition (Triplet)
18
+ # ─────────────────────────────────────────────────────────────────────────────
19
+ class BaseCNN(nn.Module):
20
+ def __init__(self):
21
+ super().__init__()
22
+ self.conv = nn.Sequential(
23
+ nn.Conv2d(3, 32, 3, 1, 1), nn.ReLU(), nn.MaxPool2d(2),
24
+ nn.Conv2d(32, 64, 3, 1, 1), nn.ReLU(), nn.MaxPool2d(2),
25
+ nn.Conv2d(64, 128, 3, 1, 1), nn.ReLU(), nn.AdaptiveAvgPool2d(1),
26
+ )
27
+
28
+ def forward(self, x):
29
+ return self.conv(x).view(x.size(0), -1)
30
+
31
+ EMBED_DIM = 128
32
+ embed_head = nn.Sequential(
33
+ nn.Linear(128, EMBED_DIM),
34
+ nn.BatchNorm1d(EMBED_DIM),
35
+ nn.ReLU(),
36
+ )
37
+
38
+ DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
39
+
40
+ tri_model = nn.Sequential(BaseCNN(), embed_head).to(DEVICE)
41
+ model_path = Path("models/tri_model.pth")
42
+ if not model_path.exists():
43
+ raise FileNotFoundError("Model weights not found at models/tri_model.pth")
44
+ tri_model.load_state_dict(torch.load(model_path, map_location=DEVICE))
45
+ tri_model.eval()
46
+
47
+ transform = transforms.Compose(
48
+ [
49
+ transforms.Resize((64, 64)),
50
+ transforms.ToTensor(),
51
+ transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
52
+ ]
53
+ )
54
+
55
+ # ─────────────────────────────────────────────────────────────────────────────
56
+ # FastAPI App Setup
57
+ # ─────────────────────────────────────────────────────────────────────────────
58
+ app = FastAPI(title="Face Similarity Comparator", docs_url="/docs")
59
+
60
+ # Allow frontend to call backend when served from the same origin in Spaces
61
+ app.add_middleware(
62
+ CORSMiddleware,
63
+ allow_origins=["*"],
64
+ allow_methods=["*"],
65
+ allow_headers=["*"],
66
+ )
67
+
68
+ # Serve static files (index.html, JS, CSS)
69
+ static_dir = Path(__file__).parent / "static"
70
+ app.mount("/static", StaticFiles(directory=static_dir), name="static")
71
+
72
+ # Redirect root to frontend
73
+ @app.get("/")
74
+ async def root_index():
75
+ return FileResponse(static_dir / "index.html")
76
+
77
+ # ─────────────────────────────────────────────────────────────────────────────
78
+ # Helper Functions
79
+ # ─────────────────────────────────────────────────────────────────────────────
80
+
81
+ def bytes_to_cv2(bytes_data: bytes) -> np.ndarray:
82
+ img_array = np.frombuffer(bytes_data, np.uint8)
83
+ img_bgr = cv2.imdecode(img_array, cv2.IMREAD_COLOR)
84
+ if img_bgr is None:
85
+ raise ValueError("Invalid image data")
86
+ return img_bgr
87
+
88
+
89
+ def get_embedding(img_bgr: np.ndarray) -> np.ndarray:
90
+ img_rgb = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB)
91
+ x = transform(Image.fromarray(img_rgb)).unsqueeze(0).to(DEVICE)
92
+ with torch.no_grad():
93
+ emb = tri_model(x)
94
+ emb = emb.cpu().numpy()[0]
95
+ return emb / np.linalg.norm(emb)
96
+
97
+
98
+ def cosine_similarity(vec1: np.ndarray, vec2: np.ndarray) -> float:
99
+ return float(np.dot(vec1, vec2))
100
+
101
+ # ─────────────────────────────────────────────────────────────────────────────
102
+ # API Endpoint
103
+ # ─────────────────────────────────────────────────────────────────────────────
104
+
105
+ @app.post("/compare")
106
+ async def compare_faces(
107
+ tester: UploadFile = File(...),
108
+ samplers: list[UploadFile] = File(...),
109
+ ):
110
+ """Compare tester image against up to 5 sampler images and return similarity."""
111
+
112
+ if len(samplers) == 0:
113
+ raise HTTPException(status_code=400, detail="At least one sampler image is required.")
114
+ if len(samplers) > 5:
115
+ raise HTTPException(status_code=400, detail="You can upload up to 5 sampler images only.")
116
+
117
+ # Read tester image
118
+ tester_bytes = await tester.read()
119
+ tester_img = bytes_to_cv2(tester_bytes)
120
+ tester_emb = get_embedding(tester_img)
121
+
122
+ results = []
123
+ for smp in samplers:
124
+ smp_bytes = await smp.read()
125
+ smp_img = bytes_to_cv2(smp_bytes)
126
+ smp_emb = get_embedding(smp_img)
127
+ sim = cosine_similarity(tester_emb, smp_emb)
128
+ results.append({
129
+ "name": smp.filename,
130
+ "similarity": round(sim * 100, 2),
131
+ })
132
+
133
+ # Sort high → low
134
+ results.sort(key=lambda x: x["similarity"], reverse=True)
135
+
136
+ return JSONResponse({"tester": tester.filename, "results": results})
requirements.txt ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ fastapi
2
+ uvicorn[standard]
3
+ python-multipart
4
+ pillow
5
+ torch
6
+ torchvision
7
+ opencv-python-headless
8
+ numpy
9
+ aiofiles
10
+ jinja2
static/index.html ADDED
@@ -0,0 +1,44 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ <!DOCTYPE html>
2
+ <html lang="en">
3
+ <head>
4
+ <meta charset="UTF-8" />
5
+ <meta name="viewport" content="width=device-width, initial-scale=1.0" />
6
+ <title>Face Similarity Comparator</title>
7
+ <link rel="stylesheet" href="/static/styles.css" />
8
+ </head>
9
+ <body>
10
+ <header>
11
+ <h1>👨‍👩‍👧‍👦 Face Similarity Comparator</h1>
12
+ <p>
13
+ Curious who your kids resemble the most? Wonder which celebrity shares
14
+ your features? Upload photos below and let our AI judge facial
15
+ similarity for fun family debates, twin‑spotting games, or social‑media
16
+ challenges. For best accuracy use front‑facing photos, evenly lit with
17
+ no glasses, masks, or hair covering facial landmarks.
18
+ </p>
19
+ </header>
20
+
21
+ <main id="upload-section">
22
+ <div class="upload-wrapper">
23
+ <h2>Tester Photo</h2>
24
+ <input id="tester" type="file" accept="image/*" />
25
+ <small>Upload <strong>one-only</strong> clear, front‑view face.</small>
26
+ </div>
27
+
28
+ <div class="upload-wrapper">
29
+ <h2>Sampler Photos (max 5)</h2>
30
+ <input id="samplers" type="file" accept="image/*" multiple />
31
+ <small>Upload up to five comparison faces.</small>
32
+ </div>
33
+ </main>
34
+
35
+ <div id="action-container"></div>
36
+
37
+ <section id="results" class="hidden">
38
+ <canvas id="lines-canvas"></canvas>
39
+ <!-- squares injected by JS -->
40
+ </section>
41
+
42
+ <script src="/static/script.js"></script>
43
+ </body>
44
+ </html>
static/script.js ADDED
@@ -0,0 +1,129 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ const testerInput = document.getElementById("tester");
2
+ const samplerInput = document.getElementById("samplers");
3
+ const actionContainer = document.getElementById("action-container");
4
+ const resultsSection = document.getElementById("results");
5
+ const canvas = document.getElementById("lines-canvas");
6
+ const ctx = canvas.getContext("2d");
7
+
8
+ let generateBtn, loaderDiv;
9
+
10
+ function resetUI() {
11
+ resultsSection.classList.add("hidden");
12
+ resultsSection.innerHTML = "<canvas id=\"lines-canvas\"></canvas>";
13
+ canvas.width = canvas.height = 0;
14
+ if (loaderDiv) loaderDiv.remove();
15
+ if (generateBtn) generateBtn.remove();
16
+ }
17
+
18
+ testerInput.addEventListener("change", handleFileChange);
19
+ samplerInput.addEventListener("change", handleFileChange);
20
+
21
+ function handleFileChange() {
22
+ resetUI();
23
+ const testerReady = testerInput.files.length === 1;
24
+ const samplersReady =
25
+ samplerInput.files.length > 0 && samplerInput.files.length <= 5;
26
+
27
+ if (samplerInput.files.length > 5) {
28
+ alert("You can upload a maximum of 5 sampler photos.");
29
+ samplerInput.value = "";
30
+ return;
31
+ }
32
+
33
+ if (testerReady && samplersReady) {
34
+ generateBtn = document.createElement("button");
35
+ generateBtn.textContent = "Generate Comparison";
36
+ generateBtn.className = "generate-btn";
37
+ generateBtn.onclick = sendComparison;
38
+ actionContainer.appendChild(generateBtn);
39
+ }
40
+ }
41
+
42
+ function showLoader() {
43
+ loaderDiv = document.createElement("div");
44
+ loaderDiv.className = "loader";
45
+ actionContainer.appendChild(loaderDiv);
46
+ }
47
+
48
+ async function sendComparison() {
49
+ generateBtn.disabled = true;
50
+ showLoader();
51
+
52
+ const formData = new FormData();
53
+ formData.append("tester", testerInput.files[0]);
54
+ for (let file of samplerInput.files) formData.append("samplers", file);
55
+
56
+ try {
57
+ const res = await fetch("/compare", { method: "POST", body: formData });
58
+ if (!res.ok) throw new Error(await res.text());
59
+ const data = await res.json();
60
+ renderResults(data);
61
+ } catch (err) {
62
+ alert("Error: " + err.message);
63
+ } finally {
64
+ generateBtn.disabled = false;
65
+ loaderDiv.remove();
66
+ }
67
+ }
68
+
69
+ function getBorderColor(percent) {
70
+ if (percent < 30) return "#ef4444"; // red
71
+ if (percent <= 50) return "#fb923c"; // orange
72
+ if (percent <= 80) return "#22c55e"; // green
73
+ return "#a855f7"; // purple
74
+ }
75
+
76
+ function renderResults({ tester, results }) {
77
+ resultsSection.classList.remove("hidden");
78
+ const testerSquare = createSquare(testerInput.files[0], "#3b82f6");
79
+ resultsSection.appendChild(testerSquare);
80
+
81
+ results.forEach((r, idx) => {
82
+ const square = createSquare(samplerInput.files[idx], getBorderColor(r.similarity));
83
+ resultsSection.appendChild(square);
84
+ drawLineBetween(testerSquare, square, r.similarity);
85
+ });
86
+ }
87
+
88
+ function createSquare(file, borderColor) {
89
+ const url = URL.createObjectURL(file);
90
+ const div = document.createElement("div");
91
+ div.className = "square";
92
+ div.style.border = `4px solid ${borderColor}`;
93
+ const img = document.createElement("img");
94
+ img.src = url;
95
+ div.appendChild(img);
96
+ return div;
97
+ }
98
+
99
+ function drawLineBetween(el1, el2, similarity) {
100
+ const rect1 = el1.getBoundingClientRect();
101
+ const rect2 = el2.getBoundingClientRect();
102
+ const offset = resultsSection.getBoundingClientRect();
103
+
104
+ // Compute canvas size on first call
105
+ if (canvas.width === 0) {
106
+ canvas.width = resultsSection.scrollWidth;
107
+ canvas.height = resultsSection.getBoundingClientRect().height;
108
+ }
109
+
110
+ const x1 = rect1.right - offset.left;
111
+ const y1 = rect1.top + rect1.height / 2 - offset.top;
112
+ const x2 = rect2.left - offset.left;
113
+ const y2 = rect2.top + rect2.height / 2 - offset.top;
114
+
115
+ ctx.strokeStyle = getBorderColor(similarity);
116
+ ctx.lineWidth = 3;
117
+ ctx.beginPath();
118
+ ctx.moveTo(x1, y1);
119
+ ctx.lineTo(x2, y2);
120
+ ctx.stroke();
121
+
122
+ // Label
123
+ const label = document.createElement("div");
124
+ label.className = "label";
125
+ label.textContent = similarity.toFixed(0) + "%";
126
+ label.style.left = (x1 + x2) / 2 + "px";
127
+ label.style.top = (y1 + y2) / 2 + "px";
128
+ resultsSection.appendChild(label);
129
+ }
static/styles.css ADDED
@@ -0,0 +1,119 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ * {
2
+ box-sizing: border-box;
3
+ margin: 0;
4
+ padding: 0;
5
+ font-family: system-ui, sans-serif;
6
+ }
7
+
8
+ body {
9
+ display: flex;
10
+ flex-direction: column;
11
+ align-items: center;
12
+ gap: 1.5rem;
13
+ padding: 1rem;
14
+ max-width: 960px;
15
+ margin-inline: auto;
16
+ }
17
+
18
+ header {
19
+ text-align: center;
20
+ }
21
+
22
+ #upload-section {
23
+ display: flex;
24
+ flex-wrap: wrap;
25
+ gap: 2rem;
26
+ justify-content: center;
27
+ width: 100%;
28
+ }
29
+
30
+ .upload-wrapper {
31
+ border: 2px dashed #aaa;
32
+ border-radius: 0.75rem;
33
+ padding: 1rem;
34
+ flex: 1 1 280px;
35
+ max-width: 420px;
36
+ text-align: center;
37
+ }
38
+
39
+ input[type="file"] {
40
+ margin-block: 0.75rem;
41
+ }
42
+
43
+ #action-container {
44
+ margin-top: 1rem;
45
+ }
46
+
47
+ button.generate-btn {
48
+ padding: 0.5rem 1.25rem;
49
+ font-size: 1.1rem;
50
+ background: #1e88e5;
51
+ color: #fff;
52
+ border: none;
53
+ border-radius: 0.5rem;
54
+ cursor: pointer;
55
+ }
56
+ button.generate-btn:disabled {
57
+ opacity: 0.6;
58
+ cursor: not-allowed;
59
+ }
60
+
61
+ .loader {
62
+ border: 4px solid #f3f3f3;
63
+ border-top: 4px solid #1e88e5;
64
+ border-radius: 50%;
65
+ width: 42px;
66
+ height: 42px;
67
+ animation: spin 1s linear infinite;
68
+ margin-inline: auto;
69
+ margin-top: 1rem;
70
+ }
71
+ @keyframes spin {
72
+ 0% { transform: rotate(0deg); }
73
+ 100% { transform: rotate(360deg); }
74
+ }
75
+
76
+ #results {
77
+ position: relative;
78
+ width: 100%;
79
+ margin-top: 2rem;
80
+ display: flex;
81
+ gap: 2rem;
82
+ justify-content: flex-start;
83
+ overflow-x: auto;
84
+ }
85
+
86
+ .square {
87
+ width: 120px;
88
+ height: 120px;
89
+ border-radius: 0.5rem;
90
+ overflow: hidden;
91
+ position: relative;
92
+ flex-shrink: 0;
93
+ }
94
+ .square img {
95
+ width: 100%;
96
+ height: 100%;
97
+ object-fit: cover;
98
+ }
99
+
100
+ #lines-canvas {
101
+ position: absolute;
102
+ top: 0;
103
+ left: 0;
104
+ pointer-events: none;
105
+ }
106
+
107
+ .label {
108
+ position: absolute;
109
+ background: rgba(0, 0, 0, 0.7);
110
+ color: #fff;
111
+ padding: 2px 6px;
112
+ border-radius: 4px;
113
+ font-size: 0.8rem;
114
+ transform: translate(-50%, -50%);
115
+ }
116
+
117
+ @media (max-width: 600px) {
118
+ .square { width: 90px; height: 90px; }
119
+ }