Midnightar commited on
Commit
18c23ed
·
verified ·
1 Parent(s): 5f2c279

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +35 -39
app.py CHANGED
@@ -4,20 +4,16 @@ import numpy as np
4
  import requests
5
  from fastapi import FastAPI
6
  from pydantic import BaseModel
 
 
7
  import insightface
8
  import gradio as gr
9
- from fastapi.middleware.cors import CORSMiddleware
10
- from fastapi.staticfiles import StaticFiles
11
 
12
- # -------------------------------------------
13
- # Load Face Model
14
- # -------------------------------------------
15
  model = insightface.app.FaceAnalysis(name="buffalo_l")
16
  model.prepare(ctx_id=0, det_size=(640, 640))
17
 
18
- # -------------------------------------------
19
- # FastAPI app
20
- # -------------------------------------------
21
  app = FastAPI()
22
 
23
  # CORS for FlutterFlow
@@ -30,9 +26,7 @@ app.add_middleware(
30
  )
31
 
32
 
33
- # -------------------------------------------
34
- # Request Schema
35
- # -------------------------------------------
36
  class CompareRequest(BaseModel):
37
  image1: str | None = None
38
  image2: str | None = None
@@ -40,9 +34,7 @@ class CompareRequest(BaseModel):
40
  image2_url: str | None = None
41
 
42
 
43
- # -------------------------------------------
44
- # Helper Functions
45
- # -------------------------------------------
46
  def b64_to_img(b64_string):
47
  try:
48
  img_data = base64.b64decode(b64_string)
@@ -52,7 +44,6 @@ def b64_to_img(b64_string):
52
  except:
53
  return None
54
 
55
-
56
  def url_to_img(url):
57
  try:
58
  resp = requests.get(url, timeout=5)
@@ -62,7 +53,6 @@ def url_to_img(url):
62
  except:
63
  return None
64
 
65
-
66
  def get_embedding(img):
67
  faces = model.get(img)
68
  if len(faces) == 0:
@@ -70,9 +60,7 @@ def get_embedding(img):
70
  return faces[0].embedding
71
 
72
 
73
- # -------------------------------------------
74
- # API Endpoint
75
- # -------------------------------------------
76
  @app.post("/compare")
77
  async def compare_faces(req: CompareRequest):
78
 
@@ -93,35 +81,38 @@ async def compare_faces(req: CompareRequest):
93
  img2 = None
94
 
95
  if img1 is None or img2 is None:
96
- return {"error": "Invalid image data or URL"}
97
 
98
  emb1 = get_embedding(img1)
99
  emb2 = get_embedding(img2)
100
 
101
  if emb1 is None or emb2 is None:
102
- return {"error": "No face detected"}
103
 
104
- similarity = float(np.dot(emb1, emb2) / (np.linalg.norm(emb1) * np.linalg.norm(emb2)))
 
 
105
  matched = similarity > 0.55
106
 
107
- return {"similarity": similarity, "match": matched}
 
 
 
108
 
109
 
110
- # -------------------------------------------
111
- # Gradio UI (embedded inside FastAPI)
112
- # -------------------------------------------
113
  def gradio_ui(img1_text, img2_text):
114
-
115
- def load_any(s):
116
  if s.startswith("http"):
117
  return url_to_img(s)
118
  return b64_to_img(s)
119
 
120
- img1 = load_any(img1_text)
121
- img2 = load_any(img2_text)
122
 
123
  if img1 is None or img2 is None:
124
- return "Invalid image or URL"
125
 
126
  emb1 = get_embedding(img1)
127
  emb2 = get_embedding(img2)
@@ -129,20 +120,25 @@ def gradio_ui(img1_text, img2_text):
129
  if emb1 is None or emb2 is None:
130
  return "Face not detected."
131
 
132
- similarity = float(np.dot(emb1, emb2) / (np.linalg.norm(emb1) * np.linalg.norm(emb2)))
133
- matched = similarity > 0.55
 
 
134
 
135
- return f"Similarity: {similarity:.3f} | Match: {matched}"
136
 
137
 
138
- ui = gr.Interface(
139
  fn=gradio_ui,
140
  inputs=[
141
- gr.Textbox(label="Image 1 (base64 or URL)"),
142
- gr.Textbox(label="Image 2 (base64 or URL)"),
143
  ],
144
  outputs="text",
145
- title="Face Match API",
146
  )
147
 
148
- app = gr.mount_gradio_app(app, ui, path="/")
 
 
 
 
4
  import requests
5
  from fastapi import FastAPI
6
  from pydantic import BaseModel
7
+ from fastapi.responses import HTMLResponse
8
+ from fastapi.middleware.cors import CORSMiddleware
9
  import insightface
10
  import gradio as gr
 
 
11
 
12
+ # ---------- Load Face Detector ----------
 
 
13
  model = insightface.app.FaceAnalysis(name="buffalo_l")
14
  model.prepare(ctx_id=0, det_size=(640, 640))
15
 
16
+ # ---------- FastAPI App ----------
 
 
17
  app = FastAPI()
18
 
19
  # CORS for FlutterFlow
 
26
  )
27
 
28
 
29
+ # ---------- API Request Model ----------
 
 
30
  class CompareRequest(BaseModel):
31
  image1: str | None = None
32
  image2: str | None = None
 
34
  image2_url: str | None = None
35
 
36
 
37
+ # ---------- Helpers ----------
 
 
38
  def b64_to_img(b64_string):
39
  try:
40
  img_data = base64.b64decode(b64_string)
 
44
  except:
45
  return None
46
 
 
47
  def url_to_img(url):
48
  try:
49
  resp = requests.get(url, timeout=5)
 
53
  except:
54
  return None
55
 
 
56
  def get_embedding(img):
57
  faces = model.get(img)
58
  if len(faces) == 0:
 
60
  return faces[0].embedding
61
 
62
 
63
+ # ---------- POST /compare ----------
 
 
64
  @app.post("/compare")
65
  async def compare_faces(req: CompareRequest):
66
 
 
81
  img2 = None
82
 
83
  if img1 is None or img2 is None:
84
+ return {"error": "Invalid image data or URL."}
85
 
86
  emb1 = get_embedding(img1)
87
  emb2 = get_embedding(img2)
88
 
89
  if emb1 is None or emb2 is None:
90
+ return {"error": "No face detected."}
91
 
92
+ similarity = np.dot(emb1, emb2) / (
93
+ np.linalg.norm(emb1) * np.linalg.norm(emb2)
94
+ )
95
  matched = similarity > 0.55
96
 
97
+ return {
98
+ "similarity": float(similarity),
99
+ "match": matched
100
+ }
101
 
102
 
103
+ # ---------- Gradio UI ----------
 
 
104
  def gradio_ui(img1_text, img2_text):
105
+ # Detect format automatically
106
+ def load(s):
107
  if s.startswith("http"):
108
  return url_to_img(s)
109
  return b64_to_img(s)
110
 
111
+ img1 = load(img1_text)
112
+ img2 = load(img2_text)
113
 
114
  if img1 is None or img2 is None:
115
+ return "Invalid image."
116
 
117
  emb1 = get_embedding(img1)
118
  emb2 = get_embedding(img2)
 
120
  if emb1 is None or emb2 is None:
121
  return "Face not detected."
122
 
123
+ similarity = np.dot(emb1, emb2) / (
124
+ np.linalg.norm(emb1) * np.linalg.norm(emb2)
125
+ )
126
+ match = similarity > 0.55
127
 
128
+ return f"Similarity: {similarity:.3f} | Match: {match}"
129
 
130
 
131
+ demo = gr.Interface(
132
  fn=gradio_ui,
133
  inputs=[
134
+ gr.Textbox(label="Image 1 (URL or Base64)"),
135
+ gr.Textbox(label="Image 2 (URL or Base64)")
136
  ],
137
  outputs="text",
138
+ title="Face Compare API"
139
  )
140
 
141
+ # ---------- Serve Gradio UI at "/" ----------
142
+ @app.get("/", response_class=HTMLResponse)
143
+ async def root():
144
+ return demo.launch(share=False, inline=True)