File size: 6,265 Bytes
19ea92a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
import torch
import asyncio
import traceback
import numpy as np
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from aiohttp import ClientSession
from PIL import Image, ImageFilter
from io import BytesIO
from facenet_pytorch import MTCNN, InceptionResnetV1
import uvicorn
import os

class FaceMainExecutionPytorch:
    """Asynchronous face-verification pipeline built on facenet-pytorch.

    Workflow: download/load two images, reject blurry ones, detect and crop
    one face per image with MTCNN, embed each face with InceptionResnetV1
    (VGGFace2 weights), then compare the embeddings by Euclidean distance.
    """

    def __init__(self):
        # Match threshold: embeddings closer than this (L2 distance) are
        # treated as the same person.
        self.MVL = 0.90
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        # NOTE(review): the detector is left on MTCNN's default (CPU) device;
        # pass device=self.device if GPU detection is desired — confirm intent.
        self.detector = MTCNN()
        self.resnet = InceptionResnetV1(pretrained="vggface2").eval().to(self.device)

    async def calculate_blurriness(self, img) -> float:
        """Return a sharpness score for *img* (higher = sharper).

        Uses the variance of an edge-filtered grayscale copy — a common
        focus measure (FIND_EDGES approximates a Laplacian kernel).
        """
        if img.mode != "L":
            img = img.convert("L")
        laplacian = img.filter(ImageFilter.FIND_EDGES)
        laplacian_array = np.array(laplacian, dtype=np.float32)
        return laplacian_array.var()

    async def extract_face(self, image, detector=None, size=(160, 160)):
        """Detect the first face in *image*; return it cropped and resized.

        Args:
            image: PIL image to search.
            detector: unused; retained for backward compatibility
                (self.detector is always used, as in the original code).
            size: output (width, height) of the face crop.

        Returns:
            The resized face crop, or None when no face is found or
            detection raises.
        """
        try:
            boxes, _ = self.detector.detect(np.array(image))
            if boxes is None or len(boxes) == 0:
                print("No face detected in the image.")
                return None

            # BUG FIX: MTCNN.detect returns boxes as corner coordinates
            # (x1, y1, x2, y2), not (x, y, width, height). The previous code
            # cropped (x1, y1, x1 + x2, y1 + y2), producing an oversized box.
            x1, y1, x2, y2 = (int(v) for v in boxes[0])
            x1, y1 = max(x1, 0), max(y1, 0)

            face = image.crop((x1, y1, x2, y2))
            return face.resize(size)
        except Exception as e:
            print(f"Error extracting face: {e}")
            traceback.print_exc()
            return None

    async def calculate_embedding(self, face_image):
        """Return the face embedding as a numpy array, or None on error."""
        try:
            # HWC uint8 -> NCHW float normalized to [-1, 1], as the
            # InceptionResnetV1 weights expect.
            face_tensor = torch.tensor(np.array(face_image)).permute(2, 0, 1).unsqueeze(0).float() / 255.0
            face_tensor = (face_tensor - 0.5) / 0.5
            face_tensor = face_tensor.to(self.device)
            # Inference only: disable autograd to avoid building a graph.
            with torch.no_grad():
                embedding = self.resnet(face_tensor)
            return embedding.cpu().numpy()
        except Exception as e:
            print(f"Error calculating embedding: {e}")
            traceback.print_exc()
            return None

    async def compare_faces(self, embedding1, embedding2):
        """Return True when the embeddings are within self.MVL (L2 distance)."""
        distance = round(float(np.linalg.norm(embedding1 - embedding2)), 2)
        # BUG FIX: the old log line said "Cosine Distance" but this is the
        # Euclidean (L2) distance between the embeddings.
        print(f"Euclidean Distance: {distance}")
        return distance < self.MVL

    async def get_image_from_url(self, img_path):
        """Download (URL) or open (local path) an image; reject blurry ones.

        Returns the RGB image, or None when loading fails or the image's
        sharpness score is below 30.
        """
        try:
            if img_path.startswith("http://") or img_path.startswith("https://"):
                async with ClientSession() as session:
                    async with session.get(img_path) as response:
                        if response.status != 200:
                            raise ValueError(f"HTTP Error: {response.status}")
                        img_data = await response.read()
                        # verify() checks integrity but leaves the file object
                        # unusable, so reopen from the raw bytes afterwards.
                        image = Image.open(BytesIO(img_data))
                        image.verify()
                        image = Image.open(BytesIO(img_data)).convert("RGB")
            elif os.path.isfile(img_path):
                image = Image.open(img_path).convert("RGB")
            else:
                raise ValueError("Invalid input path. Must be a valid URL or local file path.")

            blur = await self.calculate_blurriness(image)
            # BUG FIX: the old check `if blur and blur < 30` let a score of
            # exactly 0 (a completely flat image) through as "good".
            if blur < 30:
                print(f"===================>>>>>>>>>> Blurry: {blur}")
                return None
            print('Good Image: ', blur)
            return image
        except Exception as e:
            print(f"Error downloading or loading image: {e}")
            traceback.print_exc()
            return None

    async def compare_faces_from_urls(self, url1, url2):
        """Full pipeline: fetch both images concurrently and compare faces."""
        try:
            tasks = [
                asyncio.create_task(self.get_image_from_url(url1)),
                asyncio.create_task(self.get_image_from_url(url2))
            ]
            img1, img2 = await asyncio.gather(*tasks)

            if img1 is None or img2 is None:
                return False

            face1 = await self.extract_face(img1, self.detector)
            face2 = await self.extract_face(img2, self.detector)

            if face1 is None or face2 is None:
                return False

            embedding1 = await self.calculate_embedding(face1)
            embedding2 = await self.calculate_embedding(face2)

            if embedding1 is None or embedding2 is None:
                return False

            return await self.compare_faces(embedding1, embedding2)
        except Exception as e:
            print(f"Error comparing faces from URLs: {e}")
            traceback.print_exc()
            return False

    async def faceMain(self, BODY):
        """Entry point: BODY is a dict with 'img1' and 'img2'; returns bool."""
        try:
            return bool(await self.compare_faces_from_urls(BODY['img1'], BODY['img2']))
        except Exception:
            traceback.print_exc()
            return False

# FastAPI server setup: the application object that exposes the
# face-comparison service over HTTP.
app = FastAPI()

class Body(BaseModel):
    """Request payload for /compare_faces: two image sources, each either an
    http(s) URL or a local file path."""
    # First image to compare (URL or local path).
    img1: str
    # Second image to compare (URL or local path).
    img2: str

# Initialize the face comparison class once at import time; the MTCNN
# detector and ResNet weights are loaded here and shared across requests.
face_executor = FaceMainExecutionPytorch()

@app.post("/compare_faces")
async def compare_faces(body: Body):
    """Compare the faces referenced by ``body.img1`` and ``body.img2``.

    Responds with a bare JSON boolean: true when the two images contain the
    same face, false otherwise. Unexpected failures become HTTP 500.
    """
    try:
        outcome = await face_executor.faceMain(body.dict())
        print(outcome)
        return bool(outcome)
    except Exception:
        traceback.print_exc()
        raise HTTPException(status_code=500, detail="Internal Server Error")

if __name__ == "__main__":
    try:
        # Serve on localhost only; change host to 0.0.0.0 to expose externally.
        uvicorn.run(app, host="127.0.0.1", port=8888)
    finally:
        torch.cuda.empty_cache()  # Clear CUDA memory after the server stops