Spaces:
Runtime error
Runtime error
import os
import secrets
from io import BytesIO

import cv2
import torch
import torch.nn.functional as F
from facenet_pytorch import MTCNN, InceptionResnetV1
from fastapi import FastAPI, File, UploadFile, Depends, HTTPException, Header
from fastapi.responses import JSONResponse
from PIL import Image
from pytorch_grad_cam import GradCAM
from pytorch_grad_cam.utils.image import show_cam_on_image
from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
app = FastAPI()

# Shared secret that clients must send in the ``api-key`` header.
# SECURITY: the literal fallback below is a credential committed to source
# control. Configure the key through the API_KEY environment variable
# (e.g. a Space secret), then rotate and remove this hard-coded value.
API_KEY = os.environ.get("API_KEY", "c50dd5ady0uRL0rdnSaVyrArYaN161edb06af8")
def get_api_key(api_key: str = Header(...)):
    """Validate the ``api-key`` request header against ``API_KEY``.

    FastAPI maps the ``api_key`` parameter to the ``api-key`` HTTP header
    (underscores are converted to hyphens by default).

    Returns:
        The validated key. Endpoints that declare ``Depends(get_api_key)``
        receive this value.

    Raises:
        HTTPException: 403 when the supplied key does not match ``API_KEY``.
    """
    # Constant-time comparison, so response timing does not reveal how much
    # of a guessed key was correct. Compare bytes: compare_digest() raises
    # TypeError for non-ASCII str arguments, which would surface as a 500.
    if not secrets.compare_digest(api_key.encode("utf-8"), API_KEY.encode("utf-8")):
        raise HTTPException(status_code=403, detail="Could not validate credentials")
    return api_key
DEVICE = 'cuda:0' if torch.cuda.is_available() else 'cpu'

# MTCNN face detector. post_process=False keeps the cropped face as raw
# pixel values; predict() rescales them by 1/255 itself.
mtcnn = MTCNN(select_largest=False, post_process=False, device=DEVICE).eval()

# InceptionResnetV1 backbone pretrained on VGGFace2 with a 3-class head
# (real / fake / ai_generated). When num_classes is given, facenet_pytorch
# replaces the pretrained logits with a new, untrained Linear layer, so the
# fine-tuned checkpoint below must be loaded for predictions to mean anything.
model = InceptionResnetV1(pretrained="vggface2", classify=True, num_classes=3, device=DEVICE)

DESTINATION_FILE_PATH = 'Model/3Class_32_epoch.pth'

# Load the fine-tuned weights. Previously the checkpoint was read but never
# applied, so the classifier ran with randomly initialised logits.
checkpoint = torch.load(DESTINATION_FILE_PATH, map_location=torch.device('cpu'))
if isinstance(checkpoint, torch.nn.Module):
    # Whole pickled model: take its weights.
    state_dict = checkpoint.state_dict()
elif isinstance(checkpoint, dict) and 'model_state_dict' in checkpoint:
    # Training-loop checkpoint wrapping the weights alongside other state.
    state_dict = checkpoint['model_state_dict']
else:
    # Assume a bare state dict.
    state_dict = checkpoint
model.load_state_dict(state_dict)
model.to(DEVICE)
model.eval()
def _gradcam_overlay(face: torch.Tensor, target_class: int):
    """Return a uint8 RGB Grad-CAM heatmap for ``target_class`` blended over ``face``.

    Args:
        face: model input of shape (1, 3, H, W) with values in [0, 1].
        target_class: index of the model output to explain.
    """
    # show_cam_on_image() expects an HxWx3 float image in [0, 1].
    face_image_to_plot = face.squeeze(0).permute(1, 2, 0).detach().cpu().numpy()
    target_layers = [model.block8.branch1[-1]]
    # The context manager removes the hooks GradCAM registers on the shared
    # global model once we are done, instead of relying on garbage collection.
    with GradCAM(model=model, target_layers=target_layers) as cam:
        # pytorch-grad-cam uses one target per batch element; the batch has
        # a single face, so exactly one target is supplied.
        grayscale_cam = cam(
            input_tensor=face,
            targets=[ClassifierOutputTarget(target_class)],
            eigen_smooth=True,
        )
    return show_cam_on_image(face_image_to_plot, grayscale_cam[0, :], use_rgb=True)


def predict(input_image: Image.Image):
    """Classify the face in ``input_image`` and build a Grad-CAM explanation.

    Args:
        input_image: PIL image in any mode; it is converted to RGB if needed.

    Returns:
        A tuple ``(confidences, prediction, face_with_mask)``:
        - confidences: dict mapping 'real', 'fake' and 'ai_generated' to
          softmax probabilities;
        - prediction: the label with the highest probability;
        - face_with_mask: uint8 RGB array of the 256x256 face crop overlaid
          with the Grad-CAM heatmap for the predicted class.

    Raises:
        ValueError: if MTCNN detects no face. This is a subclass of
            Exception, so existing ``except Exception`` callers still catch it.
    """
    # MTCNN needs 3-channel RGB. Convert every other mode (RGBA, L, P,
    # CMYK, ...), not just RGBA.
    if input_image.mode != 'RGB':
        input_image = input_image.convert('RGB')

    face = mtcnn(input_image)
    if face is None:
        raise ValueError('No face detected')

    face = face.unsqueeze(0)  # add batch dimension
    face = F.interpolate(face, size=(256, 256), mode='bilinear', align_corners=False)
    # MTCNN (post_process=False) yields raw pixel values; scale to [0, 1].
    face = face.to(DEVICE).to(torch.float32) / 255.0

    with torch.no_grad():
        output = torch.softmax(model(face).squeeze(0), dim=0)
    predicted_index = int(torch.argmax(output).item())

    # Explain the class the model actually predicted. Previously only
    # class 0 was ever explained, because extra targets beyond the batch
    # size are ignored.
    face_with_mask = _gradcam_overlay(face, predicted_index)

    class_indices = {0: 'real', 1: 'fake', 2: 'ai_generated'}
    confidences = {label: output[index].item() for index, label in class_indices.items()}
    return confidences, class_indices[predicted_index], face_with_mask
# NOTE(review): the original function had no route decorator, so it was never
# exposed on `app`. The "/predict" path is assumed; adjust it if clients use
# a different one.
@app.post("/predict")
async def predict_api(file: UploadFile = File(...), api_key: str = Depends(get_api_key)):
    """Run ``predict`` on an uploaded image and return JSON.

    Requires a valid ``api-key`` header (enforced by ``get_api_key``).

    Success response keys:
        confidences: per-class probabilities.
        prediction: predicted label.
        face_with_mask: hex-encoded JPEG of the Grad-CAM overlay.

    Any failure (undecodable upload, no face detected, encoding error)
    produces ``{"error": <message>}`` with HTTP 400.
    """
    contents = await file.read()
    # NOTE(review): predict() runs synchronously on the event loop. It is not
    # moved to a thread pool because GradCAM hooks the shared global model,
    # and concurrent calls would interfere with each other.
    try:
        # Inside the try so an unreadable upload is a 400, not a 500.
        image = Image.open(BytesIO(contents))
        confidences, prediction, face_with_mask = predict(image)
        # predict() returns RGB, but OpenCV encoders expect BGR channel order.
        success, buffer = cv2.imencode('.jpg', cv2.cvtColor(face_with_mask, cv2.COLOR_RGB2BGR))
        if not success:
            raise ValueError('Could not encode the Grad-CAM image as JPEG')
        return JSONResponse(content={
            "confidences": confidences,
            "prediction": prediction,
            "face_with_mask": buffer.tobytes().hex()
        })
    except Exception as e:
        return JSONResponse(content={"error": str(e)}, status_code=400)