# deepsight-api / app.py
from fastapi import FastAPI, UploadFile, File, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from fastapi.staticfiles import StaticFiles
import os
import shutil
import glob
import datetime
import torch
import torchvision
from torchvision import transforms
from torch import nn
import numpy as np
import cv2
import face_recognition
from PIL import Image as pImage
import matplotlib
matplotlib.use('Agg')  # Select the non-GUI backend before pyplot is imported
import matplotlib.pyplot as plt
import base64
app = FastAPI()
# Configure CORS
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
# Create directories if they don't exist
os.makedirs("uploaded_images", exist_ok=True)
os.makedirs("static", exist_ok=True)
# Mount static files
app.mount("/uploaded_images", StaticFiles(directory="uploaded_images"), name="uploaded_images")
app.mount("/static", StaticFiles(directory="static"), name="static")
# Configuration
im_size = 112
mean = [0.485, 0.456, 0.406]
std = [0.229, 0.224, 0.225]
sm = nn.Softmax(dim=1)
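# inv_normalize below undoes the ImageNet normalization: from y = (x - mean) / std
# it follows that x = (y - (-mean / std)) / (1 / std), which is exactly a
# Normalize with mean = -mean/std and std = 1/std. im_convert() relies on this
# to turn a normalized tensor back into displayable pixels.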
inv_normalize = transforms.Normalize(
    mean=-1 * np.divide(mean, std), std=np.divide([1, 1, 1], std))
train_transforms = transforms.Compose([
    transforms.ToPILImage(),
    transforms.Resize((im_size, im_size)),
    transforms.ToTensor(),
    transforms.Normalize(mean, std)])
ALLOWED_VIDEO_EXTENSIONS = {'mp4', 'gif', 'webm', 'avi', '3gp', 'wmv', 'flv', 'mkv'}
# Use the GPU when one is available, otherwise fall back to the CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
class Model(nn.Module):
    def __init__(self, num_classes, latent_dim=2048, lstm_layers=1, hidden_dim=2048, bidirectional=False):
        super(Model, self).__init__()
        # ResNeXt-50 backbone with its avgpool and fc head stripped off
        model = torchvision.models.resnext50_32x4d(weights=torchvision.models.ResNeXt50_32X4D_Weights.DEFAULT)
        self.model = nn.Sequential(*list(model.children())[:-2])
        # NOTE: the fourth positional argument of nn.LSTM is `bias`, not
        # `bidirectional`; this call is kept as-is so that checkpoints trained
        # with the same code still load.
        self.lstm = nn.LSTM(latent_dim, hidden_dim, lstm_layers, bidirectional)
        self.relu = nn.LeakyReLU()
        self.dp = nn.Dropout(0.4)
        self.linear1 = nn.Linear(2048, num_classes)
        self.avgpool = nn.AdaptiveAvgPool2d(1)

    def forward(self, x):
        batch_size, seq_length, c, h, w = x.shape
        x = x.view(batch_size * seq_length, c, h, w)
        fmap = self.model(x)
        x = self.avgpool(fmap)
        x = x.view(batch_size, seq_length, 2048)
        x_lstm, _ = self.lstm(x, None)
        # Classify from the last LSTM step; return the feature map for Grad-CAM
        return fmap, self.dp(self.linear1(x_lstm[:, -1, :]))
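# Illustrative sketch (kept commented out; shapes assume im_size = 112 and a
# hypothetical 20-frame clip): the forward pass takes (batch, seq_len, C, H, W)
# and returns the backbone feature map plus raw logits.
#
#     model = Model(num_classes=2)
#     clip = torch.randn(1, 20, 3, 112, 112)   # 1 video, 20 frames
#     fmap, logits = model(clip)
#     # fmap:   (20, 2048, 4, 4) ResNeXt feature maps, one per frame
#     # logits: (1, 2) class scores, softmaxed later via `sm`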
class ValidationDataset(torch.utils.data.Dataset):
    def __init__(self, video_names, sequence_length=60, transform=None):
        self.video_names = video_names
        self.transform = transform
        self.count = sequence_length

    def __len__(self):
        return len(self.video_names)

    def __getitem__(self, idx):
        video_path = self.video_names[idx]
        frames = []
        for frame in self.frame_extract(video_path):
            # Crop to the first detected face; fall back to the full frame
            faces = face_recognition.face_locations(frame)
            try:
                top, right, bottom, left = faces[0]
                frame = frame[top:bottom, left:right, :]
            except IndexError:
                pass
            frames.append(self.transform(frame))
            if len(frames) == self.count:
                break
        frames = torch.stack(frames)
        frames = frames[:self.count]
        return frames.unsqueeze(0)  # Shape: (1, seq_len, C, H, W)

    def frame_extract(self, path):
        vidObj = cv2.VideoCapture(path)
        success = True
        while success:
            success, image = vidObj.read()
            if success:
                yield image
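# Illustrative usage (commented out; the path is a placeholder): wrapping a
# single video path yields a (1, seq_len, C, H, W) tensor ready for the model,
# mirroring what process_video() does further below.
#
#     dataset = ValidationDataset(["uploaded_videos/clip.mp4"],
#                                 sequence_length=20,
#                                 transform=train_transforms)
#     input_tensor = dataset[0]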
def allowed_video_file(filename):
    return filename.split('.')[-1].lower() in ALLOWED_VIDEO_EXTENSIONS
def get_accurate_model(sequence_length):
    """Return the highest-accuracy checkpoint trained for this sequence length."""
    # Create models directory if it doesn't exist
    os.makedirs("models", exist_ok=True)
    list_models = glob.glob(os.path.join("models", "*.pt"))
    model_name = [os.path.basename(i) for i in list_models]
    sequence_model = []
    for i in model_name:
        try:
            seq = i.split("_")[3]
            if int(seq) == sequence_length:
                sequence_model.append(i)
        except (IndexError, ValueError):
            pass
    if len(sequence_model) > 1:
        # Compare accuracies numerically; a string comparison would rank "9" above "87"
        accuracy = [float(i.split("_")[1]) for i in sequence_model]
        max_index = accuracy.index(max(accuracy))
        return sequence_model[max_index]
    return sequence_model[0] if sequence_model else None
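# Checkpoint naming convention inferred from the parsing above (the concrete
# filename is hypothetical): a file such as models/model_87_45_20_best.pt is
# read as accuracy 87.45 (fields 1-2) and sequence length 20 (field 3).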
def im_convert(tensor, video_file_name=""):
    """Convert tensor to image for visualization."""
    image = tensor.to("cpu").clone().detach()
    image = image.squeeze()
    image = inv_normalize(image)
    image = image.numpy()
    image = image.transpose(1, 2, 0)
    image = image.clip(0, 1)
    return image
def generate_gradcam_heatmap(model, img, video_file_name=""):
    """Generate a Grad-CAM style heatmap showing areas of focus for deepfake detection."""
    fmap, logits = model(img)
    logits_softmax = sm(logits)
    confidence, prediction = torch.max(logits_softmax, 1)
    confidence_val = confidence.item() * 100
    pred_idx = prediction.item()
    # Class activation map: weight the last frame's feature map by the
    # classifier weights of the predicted class
    weight_softmax = model.linear1.weight.detach().cpu().numpy()
    fmap_last = fmap[-1].detach().cpu().numpy()
    nc, h, w = fmap_last.shape
    fmap_reshaped = fmap_last.reshape(nc, h * w)
    heatmap_raw = np.dot(fmap_reshaped.T, weight_softmax[pred_idx, :].T)
    heatmap_raw -= heatmap_raw.min()
    heatmap_raw /= (heatmap_raw.max() + 1e-8)  # avoid division by zero on flat maps
    heatmap_img = np.uint8(255 * heatmap_raw.reshape(h, w))
    heatmap_resized = cv2.resize(heatmap_img, (im_size, im_size))
    heatmap_colored = cv2.applyColorMap(heatmap_resized, cv2.COLORMAP_JET)
    original_img = im_convert(img[:, -1, :, :, :])
    original_img_uint8 = (original_img * 255).astype(np.uint8)
    # Blend in BGR so both inputs share a color order and cv2.imwrite saves correctly
    original_bgr = cv2.cvtColor(original_img_uint8, cv2.COLOR_RGB2BGR)
    overlay = cv2.addWeighted(original_bgr, 0.6, heatmap_colored, 0.4, 0)
    os.makedirs(os.path.join("static", "heatmaps"), exist_ok=True)
    heatmap_filename = f"{video_file_name}_heatmap_{datetime.datetime.now().strftime('%Y%m%d_%H%M%S')}.png"
    heatmap_path = os.path.join("static", "heatmaps", heatmap_filename)
    cv2.imwrite(heatmap_path, overlay)
    plt.figure(figsize=(15, 5))
    plt.subplot(1, 3, 1)
    plt.imshow(original_img)
    plt.title('Original Frame')
    plt.axis('on')
    plt.subplot(1, 3, 2)
    plt.imshow(heatmap_resized, cmap='jet')
    plt.title('Attention Heatmap')
    plt.axis('on')
    plt.subplot(1, 3, 3)
    plt.imshow(overlay[..., ::-1])  # BGR -> RGB for matplotlib
    plt.title(f'Overlay - Prediction: {"REAL" if pred_idx == 1 else "FAKE"} ({confidence_val:.1f}%)')
    plt.axis('on')
    plt.tight_layout()
    plt_filename = f"{video_file_name}_analysis_{datetime.datetime.now().strftime('%Y%m%d_%H%M%S')}.png"
    plt_path = os.path.join("static", "heatmaps", plt_filename)
    plt.savefig(plt_path, dpi=150, bbox_inches='tight')
    plt.close()
    return {
        'prediction': pred_idx,
        'confidence': confidence_val,
        'heatmap_path': f"/static/heatmaps/{heatmap_filename}",
        'analysis_path': f"/static/heatmaps/{plt_filename}"
    }
def predict_with_gradcam(model, img, video_file_name=""):
    return generate_gradcam_heatmap(model, img, video_file_name)
@app.post("/api/upload")
async def api_upload_video(file: UploadFile = File(...), sequence_length: int = 20):
if not allowed_video_file(file.filename):
raise HTTPException(status_code=400, detail="Only video files are allowed")
file_ext = file.filename.split('.')[-1]
saved_video_file = f'uploaded_video_{datetime.datetime.now().strftime("%Y%m%d_%H%M%S")}.{file_ext}'
os.makedirs("uploaded_videos", exist_ok=True)
file_path = os.path.join("uploaded_videos", saved_video_file)
with open(file_path, "wb") as buffer:
shutil.copyfileobj(file.file, buffer)
result = await process_video(file_path, sequence_length)
return {
"status": "success",
"result": result["output"],
"confidence": result["confidence"],
"accuracy": result["accuracy"],
"frames_processed": sequence_length,
"preprocessed_images": result["preprocessed_images"],
"faces_cropped_images": result["faces_cropped_images"],
"heatmap_image": result["heatmap_image"],
"analysis_image": result["analysis_image"],
"gradcam_explanation": result["gradcam_explanation"]
}
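# Illustrative request (endpoint and parameters taken from the signature above;
# video.mp4 is a placeholder):
#
#     curl -X POST "http://localhost:8000/api/upload?sequence_length=20" \
#          -F "file=@video.mp4"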
async def process_video(video_file, sequence_length):
    try:
        if not os.path.exists(video_file):
            raise HTTPException(status_code=400, detail="Video file not found")
        path_to_videos = [video_file]
        video_file_name = os.path.basename(video_file)
        video_file_name_only = os.path.splitext(video_file_name)[0]
        video_dataset = ValidationDataset(
            path_to_videos, sequence_length=sequence_length, transform=train_transforms)
        model = Model(2).to(device)
        model_filename = get_accurate_model(sequence_length)
        if not model_filename:
            raise HTTPException(
                status_code=500,
                detail=f"No suitable model found for sequence length {sequence_length}"
            )
        model_path = os.path.join("models", model_filename)
        if not os.path.exists(model_path):
            raise HTTPException(
                status_code=500,
                detail=f"Model file not found at {model_path}"
            )
        model.load_state_dict(torch.load(model_path, map_location=device))
        model.eval()
        cap = cv2.VideoCapture(video_file)
        frames = []
        while cap.isOpened():
            ret, frame = cap.read()
            if ret:
                frames.append(frame)
            else:
                break
        cap.release()
        if not frames:
            raise HTTPException(status_code=400, detail="No frames could be read from the video")
        os.makedirs(os.path.join("static", "uploaded_images"), exist_ok=True)
        preprocessed_images = []
        for i in range(1, min(sequence_length + 1, len(frames))):
            try:
                frame = frames[i]
                image = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
                img = pImage.fromarray(image, 'RGB')
                image_name = f"{video_file_name_only}_preprocessed_{i}.png"
                image_path = os.path.join("static", "uploaded_images", image_name)
                img.save(image_path)
                preprocessed_images.append(f"/static/uploaded_images/{image_name}")
            except Exception as e:
                print(f"Error processing frame {i}: {str(e)}")
                continue
        padding = 40
        faces_cropped_images = []
        faces_found = 0
        for i in range(1, min(sequence_length + 1, len(frames))):
            try:
                frame = frames[i]
                face_locations = face_recognition.face_locations(frame)
                if not face_locations:
                    continue
                # Crop the first detected face with padding, clamped to the frame bounds
                top, right, bottom, left = face_locations[0]
                frame_face = frame[
                    max(0, top - padding):min(frame.shape[0], bottom + padding),
                    max(0, left - padding):min(frame.shape[1], right + padding)
                ]
                image = cv2.cvtColor(frame_face, cv2.COLOR_BGR2RGB)
                img = pImage.fromarray(image, 'RGB')
                image_name = f"{video_file_name_only}_cropped_faces_{i}.png"
                image_path = os.path.join("static", "uploaded_images", image_name)
                img.save(image_path)
                faces_found += 1
                faces_cropped_images.append(f"/static/uploaded_images/{image_name}")
            except Exception as e:
                print(f"Error processing face in frame {i}: {str(e)}")
                continue
        if faces_found == 0:
            raise HTTPException(status_code=400, detail="No faces detected in the video")
        try:
            input_tensor = video_dataset[0].to(device)
            gradcam_result = predict_with_gradcam(model, input_tensor, video_file_name_only)
            confidence = round(gradcam_result['confidence'], 1)
            output = "REAL" if gradcam_result['prediction'] == 1 else "FAKE"
            # Model accuracy is encoded in the checkpoint filename (fields 1 and 2)
            try:
                accuracy = model_filename.split("_")[1] if len(model_filename.split("_")) > 1 else "00"
                decimal = model_filename.split("_")[2] if len(model_filename.split("_")) > 2 else "00"
            except Exception:
                accuracy = "00"
                decimal = "00"
            gradcam_explanation = {
                "description": "The heatmap shows areas where the AI model focused its attention when making the prediction.",
                "interpretation": {
                    "red_areas": "High attention - areas that strongly influenced the decision",
                    "yellow_areas": "Medium attention - moderately important areas",
                    "blue_areas": "Low attention - areas with minimal influence on the decision"
                },
                "prediction_basis": f"The model classified this video as {output} with {confidence}% confidence based on the highlighted facial regions."
            }
            return {
                "preprocessed_images": preprocessed_images,
                "faces_cropped_images": faces_cropped_images,
                "output": output,
                "confidence": confidence,
                "accuracy": accuracy,
                "decimal": decimal,
                "heatmap_image": gradcam_result['heatmap_path'],
                "analysis_image": gradcam_result['analysis_path'],
                "gradcam_explanation": gradcam_explanation
            }
        except HTTPException:
            raise
        except Exception as e:
            raise HTTPException(
                status_code=500,
                detail=f"Error making prediction: {str(e)}"
            )
    except HTTPException:
        raise
    except Exception as e:
        raise HTTPException(
            status_code=500,
            detail=f"Error processing video: {str(e)}"
        )
@app.post("/predict")
async def predict_frames(data: dict):
try:
print("Received request to /predict endpoint")
frames = data.get('frames', [])
if not frames:
print("No frames provided in request")
raise HTTPException(status_code=400, detail="No frames provided")
print(f"Processing {len(frames)} frames")
sequence_length = 20
processed_frames = []
for i, frame_base64 in enumerate(frames[:sequence_length]):
try:
if ',' in frame_base64:
frame_base64 = frame_base64.split(',')[1]
frame_data = base64.b64decode(frame_base64)
frame = cv2.imdecode(
np.frombuffer(frame_data, np.uint8),
cv2.IMREAD_COLOR
)
frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
try:
faces = face_recognition.face_locations(frame)
if faces:
top, right, bottom, left = faces[0]
height, width = frame.shape[:2]
margin = int(min(width, height) * 0.1)
top = max(0, top - margin)
bottom = min(height, bottom + margin)
left = max(0, left - margin)
right = min(width, right + margin)
frame = frame[top:bottom, left:right, :]
print(f"Face detected in frame {i+1} with margins")
else:
print(f"No face detected in frame {i+1}, using full frame")
except Exception as e:
print(f"Face detection error in frame {i+1}: {str(e)}, using full frame")
height, width = frame.shape[:2]
max_dimension = 512
if height > max_dimension or width > max_dimension:
scale = max_dimension / max(height, width)
new_width = int(width * scale)
new_height = int(height * scale)
frame = cv2.resize(frame, (new_width, new_height), interpolation=cv2.INTER_AREA)
print(f"Resized frame {i+1} to {new_width}x{new_height}")
processed_frames.append(frame)
except Exception as e:
print(f"Error processing frame {i+1}: {str(e)}")
continue
if not processed_frames:
print("No valid frames could be processed")
raise HTTPException(status_code=400, detail="No valid frames could be processed")
print(f"Successfully processed {len(processed_frames)} frames")
frames_tensor = torch.stack([
train_transforms(frame) for frame in processed_frames
])
frames_tensor = frames_tensor.unsqueeze(0)
model = Model(2).cpu()
model_filename = get_accurate_model(sequence_length)
if not model_filename:
print(f"No suitable model found for sequence length {sequence_length}")
raise HTTPException(
status_code=500,
detail=f"No suitable model found for sequence length {sequence_length}"
)
print(f"Using model: {model_filename}")
try:
parts = model_filename.split('_')
accuracy = float(parts[1])
print(f"Extracted accuracy: {accuracy}%")
if accuracy <= 0 or accuracy > 100:
print("Invalid accuracy value, using default")
accuracy = 87.0
except Exception as e:
print(f"Error extracting accuracy: {str(e)}")
accuracy = 87.0
print(f"Using default accuracy: {accuracy}%")
model_path = os.path.join("models", model_filename)
print(f"Loading model from: {model_path}")
model.load_state_dict(torch.load(model_path, map_location=torch.device('cpu')))
model.eval()
with torch.no_grad():
_, logits = model(frames_tensor)
probabilities = sm(logits)
_, prediction = torch.max(probabilities, 1)
confidence = probabilities[:, int(prediction.item())].item() * 100
is_fake = prediction.item() == 0
print(f"Prediction: {'FAKE' if is_fake else 'REAL'} with {confidence:.2f}% confidence")
print(f"Model accuracy: {accuracy}%")
response_data = {
"is_fake": is_fake,
"confidence": confidence,
"frames_processed": len(processed_frames),
"model_accuracy": accuracy
}
print(f"Sending response: {response_data}")
return response_data
except Exception as e:
print(f"Error in predict_frames: {str(e)}")
raise HTTPException(status_code=500, detail=str(e))
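# Illustrative payload (shape taken from the parsing above; the base64 content
# is a placeholder): a JSON body with up to 20 frames, either raw base64 or
# data-URI strings.
#
#     {"frames": ["data:image/jpeg;base64,/9j/4AAQSkZJRg...", "..."]}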
@app.get("/api/test")
def test_endpoint():
return {"status": "success", "message": "API is working!"}