FaceGNN / CODE /pyTorch_FaceNet_api.py
aiyubali's picture
FaceGNN Updated v1.1
19ea92a
import torch
import asyncio
import traceback
import numpy as np
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from aiohttp import ClientSession
from PIL import Image, ImageFilter
from io import BytesIO
from facenet_pytorch import MTCNN, InceptionResnetV1
import uvicorn
import os
class FaceMainExecutionPytorch:
    """Face-verification pipeline built on facenet_pytorch.

    Loads two images, detects a face in each with MTCNN, embeds both faces
    with InceptionResnetV1 (VGGFace2 weights) and reports whether the
    embeddings are closer than ``self.MVL``.

    NOTE(review): the ``async`` methods below run CPU/GPU-bound work (MTCNN
    detection, ResNet inference, PIL filtering) directly on the event loop,
    so concurrent requests are effectively processed one at a time. Consider
    offloading that work to a thread or process pool.
    """

    def __init__(self, match_threshold: float = 0.90, blur_threshold: float = 30.0):
        """Load the face detector and the embedding network.

        Args:
            match_threshold: Maximum Euclidean distance between two embeddings
                for the faces to count as a match (stored as ``self.MVL``).
            blur_threshold: Images whose sharpness score (see
                ``calculate_blurriness``) is below this value are rejected.
        """
        self.MVL = match_threshold
        self.blur_threshold = blur_threshold
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        # Run detection on the same device as the embedding network.
        self.detector = MTCNN(device=self.device)
        self.resnet = InceptionResnetV1(pretrained="vggface2").eval().to(self.device)

    async def calculate_blurriness(self, img: "Image.Image") -> float:
        """Return a sharpness score for *img*; higher means sharper.

        The image is converted to 8-bit greyscale, filtered with PIL's
        ``FIND_EDGES`` 3x3 edge kernel, and the variance of the filter
        response is returned. Blurry images have few strong edges and
        therefore a low variance.
        """
        if img.mode != "L":
            img = img.convert("L")
        edges = img.filter(ImageFilter.FIND_EDGES)
        return float(np.array(edges, dtype=np.float32).var())

    async def extract_face(self, image, detector=None, size=(160, 160)):
        """Detect and extract the face from an image object.

        Args:
            image: RGB PIL image to search.
            detector: Object exposing ``detect(array) -> (boxes, probs)`` in the
                style of ``facenet_pytorch.MTCNN``; defaults to ``self.detector``.
            size: ``(width, height)`` the cropped face is resized to.

        Returns:
            The cropped and resized face image, or ``None`` when no usable face
            was found or detection raised an error.
        """
        detector = self.detector if detector is None else detector
        try:
            boxes, _ = detector.detect(np.array(image))
            if boxes is None or len(boxes) == 0:
                print("No face detected in the image.")
                return None
            # MTCNN boxes are corner coordinates (x1, y1, x2, y2), not
            # (x, y, width, height). Clamp them to the image so the crop
            # never includes padding from outside the picture.
            img_w, img_h = image.size
            x1, y1, x2, y2 = (int(v) for v in boxes[0][:4])
            x1, y1 = max(x1, 0), max(y1, 0)
            x2, y2 = min(x2, img_w), min(y2, img_h)
            if x2 <= x1 or y2 <= y1:
                print("Detected face box is empty after clamping.")
                return None
            return image.crop((x1, y1, x2, y2)).resize(size)
        except Exception as e:
            print(f"Error extracting face: {e}")
            traceback.print_exc()
            return None

    async def calculate_embedding(self, face_image):
        """Calculate the FaceNet embedding of a cropped RGB face image.

        Returns:
            A NumPy array with a leading batch dimension of 1, or ``None`` if
            the computation raised an error.
        """
        try:
            # HWC uint8 pixels -> 1xCxHxW float tensor scaled to [-1, 1].
            face_tensor = torch.tensor(np.array(face_image)).permute(2, 0, 1).unsqueeze(0).float() / 255.0
            face_tensor = ((face_tensor - 0.5) / 0.5).to(self.device)
            # Inference only: skip autograd bookkeeping to save memory and time.
            with torch.no_grad():
                embedding = self.resnet(face_tensor)
            return embedding.cpu().numpy()
        except Exception as e:
            print(f"Error calculating embedding: {e}")
            traceback.print_exc()
            return None

    async def compare_faces(self, embedding1, embedding2):
        """Return ``True`` when two embeddings are closer than ``self.MVL``.

        The metric is the Euclidean (L2) distance between the two arrays.
        """
        distance = float(np.linalg.norm(embedding1 - embedding2))
        # Round only for display. Compare the exact value so that, e.g., 0.897
        # is not rounded up to a 0.90 threshold and rejected.
        print(f"Euclidean Distance: {round(distance, 2)}")
        return distance < self.MVL

    async def get_image_from_url(self, img_path):
        """Load an RGB image from an http(s) URL or a local file path.

        Returns:
            The decoded RGB image, or ``None`` if it could not be loaded or
            its sharpness score is below ``self.blur_threshold``.

        SECURITY NOTE(review): in this file ``img_path`` comes straight from
        the HTTP request body, so any client can make the server fetch
        arbitrary URLs (SSRF) or open arbitrary local files. Consider
        restricting inputs to an allow-list of hosts or directories.
        """
        try:
            if img_path.startswith(("http://", "https://")):
                async with ClientSession() as session:
                    async with session.get(img_path) as response:
                        if response.status != 200:
                            raise ValueError(f"HTTP Error: {response.status}")
                        img_data = await response.read()
                # verify() checks integrity but leaves that Image object
                # unusable, so decode a fresh copy afterwards.
                Image.open(BytesIO(img_data)).verify()
                image = Image.open(BytesIO(img_data)).convert("RGB")
            elif os.path.isfile(img_path):
                # Close the underlying file once the pixels have been decoded.
                with Image.open(img_path) as src:
                    image = src.convert("RGB")
            else:
                raise ValueError("Invalid input path. Must be a valid URL or local file path.")
            blur = await self.calculate_blurriness(image)
            # A score of exactly 0 (a flat image) is the blurriest possible
            # result, so it must be rejected too: no truthiness short-circuit.
            if blur < self.blur_threshold:
                print(f"===================>>>>>>>>>> Blurry: {blur}")
                return None
            print('Good Image: ', blur)
            return image
        except Exception as e:
            print(f"Error downloading or loading image: {e}")
            traceback.print_exc()
            return None

    async def compare_faces_from_urls(self, url1, url2):
        """Compare faces from two image URLs or local paths.

        Returns ``False`` if either image cannot be loaded, has no usable
        face, or cannot be embedded; otherwise returns the match decision.
        """
        try:
            # Both images are loaded concurrently.
            img1, img2 = await asyncio.gather(
                self.get_image_from_url(url1),
                self.get_image_from_url(url2),
            )
            if img1 is None or img2 is None:
                return False
            face1 = await self.extract_face(img1, self.detector)
            face2 = await self.extract_face(img2, self.detector)
            if face1 is None or face2 is None:
                return False
            embedding1 = await self.calculate_embedding(face1)
            embedding2 = await self.calculate_embedding(face2)
            if embedding1 is None or embedding2 is None:
                return False
            return await self.compare_faces(embedding1, embedding2)
        except Exception as e:
            print(f"Error comparing faces from URLs: {e}")
            traceback.print_exc()
            return False

    async def faceMain(self, BODY):
        """Entry point: compare ``BODY['img1']`` with ``BODY['img2']``.

        Returns a plain ``bool``. Any error, including a missing key, is
        logged and reported as ``False``.
        """
        try:
            result = await self.compare_faces_from_urls(BODY['img1'], BODY['img2'])
            return bool(result)
        except Exception:
            traceback.print_exc()
            return False
# Request schema for POST /compare_faces: two image references, each either
# an http(s) URL or a local file path.
class Body(BaseModel):
    img1: str
    img2: str


# FastAPI application; served by uvicorn when this module is run directly.
app = FastAPI()

# One shared executor for every request. Constructing it loads the MTCNN
# detector and the pretrained InceptionResnetV1 model.
face_executor = FaceMainExecutionPytorch()
@app.post("/compare_faces")
async def compare_faces(body: Body):
    # HTTP handler: responds with a JSON boolean saying whether img1 and img2
    # show the same face. Unexpected errors are logged and mapped to HTTP 500.
    try:
        verdict = await face_executor.faceMain(body.dict())
        print(verdict)
        return bool(verdict)
    except Exception:
        traceback.print_exc()
        raise HTTPException(status_code=500, detail="Internal Server Error")
if __name__ == "__main__":
    # Serve the API on the loopback interface only.
    bind_host, bind_port = "127.0.0.1", 8888
    try:
        uvicorn.run(app, host=bind_host, port=bind_port)
    finally:
        # Release cached CUDA memory once the server stops, even on error.
        torch.cuda.empty_cache()