Spaces:
Runtime error
| import gradio as gr | |
| import webbrowser | |
| from threading import Timer | |
| import torch | |
| from facenet_pytorch import InceptionResnetV1 | |
| import cv2 | |
| from PIL import Image, ImageOps | |
| import numpy as np | |
| import warnings | |
# Silence every warning process-wide to keep the Space logs quiet.
# NOTE(review): this also hides deprecation warnings from torch/gradio;
# consider narrowing the filter by category or module.
warnings.filterwarnings("ignore")

# Run on the GPU when torch can see one, otherwise fall back to the CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# InceptionResnetV1 backbone initialised from the "vggface2" weights, with a
# single-output classification head (num_classes=1) that load_state_dict
# below attempts to populate from the fine-tuned checkpoint.
# nn.Module.to() and nn.Module.eval() both act in place, so the calls do not
# need to be chained.
model = InceptionResnetV1(pretrained="vggface2", classify=True, num_classes=1)
model.to(DEVICE)
model.eval()

# Fine-tuned weights; the relative path is resolved against the working
# directory.
checkpoint_path = "resnetinceptionv1_epoch_32.pth"
# NOTE(review): torch.load unpickles the file, so only load trusted
# checkpoints; on torch versions that support it, weights_only=True is safer.
checkpoint = torch.load(checkpoint_path, map_location=DEVICE)

# Accept either a bare state dict or a training checkpoint that wraps the
# state dict under the "model_state_dict" key.
state_dict = checkpoint["model_state_dict"] if "model_state_dict" in checkpoint else checkpoint

# A mismatched state dict is reported but not fatal: the app keeps running
# with whatever weights the model already had.
try:
    model.load_state_dict(state_dict)
except RuntimeError as load_error:
    print(f"Error loading model weights: {load_error}")
else:
    print("Model weights loaded successfully.")


def create_montage(frames, size=(512, 512)):
    """Tile *frames* into a square grid on a single RGB canvas.

    Args:
        frames: Sized sequence of PIL images. Each one is cropped and resized
            with ``ImageOps.fit`` to fill exactly one grid cell.
        size: ``(width, height)`` of the output montage in pixels.

    Returns:
        A new RGB PIL image of exactly ``size``. The grid has
        ``ceil(sqrt(len(frames)))`` cells per side and is filled row by row
        (left to right, top to bottom); cells without a frame keep the
        canvas' default fill. An empty ``frames`` returns the blank canvas
        instead of raising.
    """
    montage = Image.new('RGB', size)
    num_images = len(frames)
    if num_images == 0:
        # Without this guard the grid size below is 0 and the cell-size
        # division raises ZeroDivisionError.
        return montage

    grid = int(np.ceil(np.sqrt(num_images)))
    # Clamp to 1 px so a huge number of frames never yields a zero-sized
    # target, which ImageOps.fit cannot produce.
    thumb_size = (max(1, size[0] // grid), max(1, size[1] // grid))

    for index, frame in enumerate(frames):
        thumbnail = ImageOps.fit(frame, thumb_size, Image.Resampling.LANCZOS)
        row, col = divmod(index, grid)
        montage.paste(thumbnail, (col * thumb_size[0], row * thumb_size[1]))
    return montage


def _sample_frames(input_video, num_frames=9):
    """Decode *input_video* and return up to *num_frames* evenly spaced RGB frames.

    Target positions are spread with ``np.linspace`` over the frame count the
    container reports. That count may be 0 or inaccurate for some files: a
    non-positive count yields ``[]``, and an over-estimate simply stops at the
    real end of the stream, returning fewer frames. The result may be empty.
    """
    cap = cv2.VideoCapture(input_video)
    try:
        if not cap.isOpened():
            return []
        total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        if total_frames <= 0:
            return []
        # Build the lookup set once so the per-frame membership test is O(1);
        # duplicate indices (videos shorter than num_frames) collapse here.
        wanted = set(np.linspace(0, total_frames - 1, num_frames, dtype=int).tolist())
        frames = []
        for index in range(total_frames):
            ok, frame = cap.read()
            if not ok:
                break
            if index in wanted:
                # OpenCV decodes to BGR; PIL expects RGB.
                frames.append(Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)))
        return frames
    finally:
        # Release the decoder even if reading or conversion raises.
        cap.release()


def predict(input_video):
    """Gradio handler: label *input_video* and build a montage of sampled frames.

    Args:
        input_video: The value delivered by the ``gr.Video`` input (a path
            that ``cv2.VideoCapture`` can open), or ``None`` when the user
            submits without a file.

    Returns:
        ``(label, detail, montage)``: ``"Fake"`` or ``"Real"``, a detail
        string with the score as a percentage, and a PIL montage of up to nine
        evenly spaced frames.

    Raises:
        gr.Error: if no video was supplied or no frame could be decoded;
            Gradio shows the message to the user.

    NOTE(review): the score is a random placeholder. The module-level model is
    not called here, so real inference still has to be wired in.
    """
    if input_video is None:
        raise gr.Error("Please upload a video.")

    selected_frames = _sample_frames(input_video)
    if not selected_frames:
        raise gr.Error("Could not read any frames from the uploaded video.")

    # Placeholder score. The label is derived from it so the two outputs
    # never contradict each other.
    fake_ratio = np.random.rand()
    video_label = "Fake" if fake_ratio > 0.5 else "Real"
    detail = f"Placeholder ratio: {fake_ratio*100:.2f}%"
    montage = create_montage(selected_frames)
    return video_label, detail, montage


# Gradio UI: one video input and three outputs matching predict()'s return
# tuple (label text, detail text, montage image).
# NOTE(review): predict() returns random placeholder results, which the
# description below does not mention.
interface = gr.Interface(
    fn=predict,
    inputs=gr.Video(label="Input Video"),
    outputs=[
        gr.Text(label="Classification"),
        gr.Text(label="Details"),
        gr.Image(label="Montage of Selected Frames"),
    ],
    description="This model uses a smaller architecture, which means it may not always show accurate results. It's designed for demonstration purposes and might not perform optimally on every input.",
)

# Keep `interface` bound to the Interface object rather than to launch()'s
# return value, and only start the server when run as a script.
if __name__ == "__main__":
    # debug=True blocks the main thread and surfaces errors in the console.
    # share=True requests a public share link; it is not supported when
    # running on Hugging Face Spaces.
    interface.launch(debug=True, share=True)