# Cell 1
import torch
from torchvision import transforms
from PIL import Image
import numpy as np
from backend.keyframes.model import DSN
import torch.nn as nn
import cv2
import time
import os
import srt
from backend.keyframes.extract_frames import extract_frames
from backend.utils import copy_and_rename_file, get_black_bar_coordinates, crop_image
import signal
import threading  # Added to check whether we are running in the main thread
# Cell 2
# Global model cache to avoid reloading
_googlenet_model = None
_preprocess_pipeline = None
def _get_features(frames, gpu=True, batch_size=1):
    # NOTE: batch_size is currently unused; frames are processed one at a time.
    global _googlenet_model, _preprocess_pipeline
    # Load the pre-trained GoogLeNet model only once
    if _googlenet_model is None:
        print("Loading GoogLeNet model (this happens only once)...")
        # The v0.10.0 hub entry point takes `pretrained=True`; the `weights=`
        # keyword only exists in newer torchvision releases.
        _googlenet_model = torch.hub.load('pytorch/vision:v0.10.0', 'googlenet', pretrained=True)
        # Remove the classification layer (last layer) to obtain 1024-d features
        _googlenet_model = torch.nn.Sequential(*(list(_googlenet_model.children())[:-1]))
        _googlenet_model.eval()
        # Initialize the preprocessing pipeline
        _preprocess_pipeline = transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ])
        # Move to GPU if requested
        if gpu:
            _googlenet_model.to('cuda')
        print("✅ GoogLeNet model loaded successfully")
    # Collect one feature vector per frame
    features = []
    for frame_path in frames:
        # Load and preprocess the frame; convert to RGB so grayscale/RGBA
        # images also yield 3-channel tensors
        input_image = Image.open(frame_path).convert('RGB')
        input_tensor = _preprocess_pipeline(input_image)
        input_batch = input_tensor.unsqueeze(0)  # Add batch dimension
        # Move the input to GPU if requested
        if gpu:
            input_batch = input_batch.to('cuda')
        # Perform feature extraction
        with torch.no_grad():
            output = _googlenet_model(input_batch)
        # Append the features to the list
        features.append(output.squeeze().cpu().numpy())
    # Convert the list of features to a NumPy array
    features = np.array(features)
    return features.astype(np.float32)
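# Minimal usage sketch for _get_features (left commented out so this cell has
# no side effects; the frame path is hypothetical). For N input frames it
# returns an (N, 1024) float32 array of GoogLeNet descriptors.
# demo_frames = ["frames/sub1/frame_000.png"]  # hypothetical path
# feats = _get_features(demo_frames, gpu=False)
# print(feats.shape)  # expected: (1, 1024)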
# Global DSN model cache
_dsn_models = {}
def _get_probs(features, gpu=True, mode=0):
    global _dsn_models
    # Cache key per (mode, device) combination
    cache_key = f"dsn_model_{mode}_{gpu}"
    # Load the model only if it is not already cached
    if cache_key not in _dsn_models:
        print(f"Loading DSN model {mode} (this happens only once)...")
        if mode == 1:
            model_path = "backend/keyframes/pretrained_model/model_1.pth.tar"
        else:
            model_path = "backend/keyframes/pretrained_model/model_0.pth.tar"
        model = DSN(in_dim=1024, hid_dim=256, num_layers=1, cell="lstm")
        if gpu:
            checkpoint = torch.load(model_path)
        else:
            checkpoint = torch.load(model_path, map_location='cpu')
        model.load_state_dict(checkpoint)
        if gpu:
            model = nn.DataParallel(model).cuda()
        model.eval()
        _dsn_models[cache_key] = model
        print(f"✅ DSN model {mode} loaded successfully")
    model = _dsn_models[cache_key]
    seq = torch.from_numpy(features).unsqueeze(0)
    if gpu:
        seq = seq.cuda()
    with torch.no_grad():
        probs = model(seq)
    return probs.cpu().squeeze().numpy()
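# End-to-end scoring sketch (commented out because it loads both models):
# features from _get_features feed _get_probs, which returns one highlight
# probability per frame from the DSN (reusing demo_frames from the sketch above).
# feats = _get_features(demo_frames, gpu=False)   # shape (N, 1024)
# scores = _get_probs(feats, gpu=False, mode=0)   # shape (N,)
# best = int(np.argmax(scores))                   # most highlight-worthy frame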
def generate_keyframes(video):
    with open("test1.srt") as f:
        data = f.read()
    # Materialize the parse result once: srt.parse returns a generator, so
    # taking len(list(subs)) first would exhaust it before iteration
    subs = list(srt.parse(data))
    torch.cuda.empty_cache()
    # Add timeout protection
    def timeout_handler(signum, frame):
        raise TimeoutError("Keyframe generation timed out")
    # Set a 10-minute timeout, but only in the main thread
    # (signal handlers cannot be installed from worker threads)
    if threading.current_thread() is threading.main_thread():
        signal.signal(signal.SIGALRM, timeout_handler)
        signal.alarm(600)  # 10 minutes
    # Create the final directory if it doesn't exist
    final_dir = os.path.join("frames", "final")
    if not os.path.exists(final_dir):
        os.makedirs(final_dir)
        print(f"Created directory: {final_dir}")
    frame_counter = 1
    total_subs = len(subs)
    print(f"🎯 Processing {total_subs} subtitle segments...")
    try:
        # Enhanced story-aware keyframe extraction
        for i, sub in enumerate(subs, 1):
            print(f"Processing segment {i}/{total_subs}: {sub.content[:30]}...")
            if not os.path.exists(f"frames/sub{sub.index}"):
                os.makedirs(f"frames/sub{sub.index}")
            # Extract more frames per segment (10 instead of 3) for better story selection
            frames = extract_frames(video, os.path.join("frames", f"sub{sub.index}"),
                                    sub.start.total_seconds(), sub.end.total_seconds(), 10)
            if len(frames) > 0:
                # Get AI highlight scores
                features = _get_features(frames, gpu=False)
                highlight_scores = _get_probs(features, gpu=False)
                # Enhanced story-aware selection
                story_frames = _select_story_relevant_frames(frames, highlight_scores, sub)
                # Save the best story frames (capped at 16 frames total)
                for j, frame_idx in enumerate(story_frames):
                    if frame_counter <= 16:
                        try:
                            copy_and_rename_file(frames[frame_idx], final_dir, f"frame{frame_counter:03}.png")
                            print(f"Story frame {frame_counter}: {sub.content} (score: {highlight_scores[frame_idx]:.3f})")
                            frame_counter += 1
                        except Exception as e:
                            print(f"Failed to save frame {frame_counter}: {e}")
            else:
                print(f"⚠️ No frames extracted for subtitle {sub.index}")
        # If no frames were successfully generated, run a fallback extraction over the full video
        if frame_counter == 1:
            print("🚨 No story-relevant frames generated; falling back to uniform extraction...")
            try:
                # Extract 16 evenly spaced frames across the entire video duration
                video_cap = cv2.VideoCapture(video)
                total_frames = int(video_cap.get(cv2.CAP_PROP_FRAME_COUNT))
                step = max(total_frames // 16, 1)
                extracted = 0
                frame_idx = 0
                while extracted < 16 and video_cap.isOpened():
                    video_cap.set(cv2.CAP_PROP_POS_FRAMES, frame_idx)
                    ret, frame = video_cap.read()
                    if not ret:
                        break
                    out_path = os.path.join(final_dir, f"frame{frame_counter:03}.png")
                    cv2.imwrite(out_path, frame)
                    frame_counter += 1
                    extracted += 1
                    frame_idx += step
                video_cap.release()
                print(f"✅ Fallback extracted {extracted} uniform frames")
            except Exception as e:
                print(f"Fallback extraction failed: {e}")
        print(f"✅ Generated {frame_counter-1} story-relevant frames")
    except TimeoutError:
        print("⏰ Keyframe generation timed out, using fallback method...")
        # Fallback: take one frame from each of the first four subtitle segments
        for i, sub in enumerate(subs[:4], 1):
            if frame_counter <= 16:
                try:
                    # Simple frame extraction without AI scoring
                    frames = extract_frames(video, os.path.join("frames", f"sub{sub.index}"),
                                            sub.start.total_seconds(), sub.end.total_seconds(), 1)
                    if frames:
                        copy_and_rename_file(frames[0], final_dir, f"frame{frame_counter:03}.png")
                        print(f"Fallback frame {frame_counter}: {sub.content}")
                        frame_counter += 1
                except Exception as e:
                    print(f"Fallback frame extraction failed: {e}")
        print(f"✅ Generated {frame_counter-1} fallback frames")
    finally:
        # Cancel the pending alarm (only set when running in the main thread)
        if threading.current_thread() is threading.main_thread():
            signal.alarm(0)
def _select_story_relevant_frames(frames, highlight_scores, subtitle):
    """Enhanced story-aware frame selection."""
    try:
        highlight_scores = list(highlight_scores)
        # 1. Analyze each frame for story relevance
        story_scores = []
        for i, frame_path in enumerate(frames):
            story_scores.append(_analyze_story_relevance(frame_path, highlight_scores[i], subtitle))
        # 2. Combine AI scores with story relevance (60% AI, 40% story)
        combined_scores = [
            (highlight_scores[i] * 0.6) + (story_scores[i] * 0.4)
            for i in range(len(frames))
        ]
        # 3. Rank frames by combined score
        sorted_combined = [i for i, _ in sorted(enumerate(combined_scores), key=lambda x: x[1], reverse=True)]
        # Return the top 2-3 frames per segment for better story coverage
        num_frames_to_select = min(3, len(frames))
        return sorted_combined[:num_frames_to_select]
    except Exception as e:
        print(f"Story selection failed: {e}")
        # Fall back to plain AI ranking
        try:
            highlight_scores = list(highlight_scores)
            sorted_indices = [i for i, _ in sorted(enumerate(highlight_scores), key=lambda x: x[1], reverse=True)]
            return [sorted_indices[0]] if sorted_indices else [0]
        except Exception:
            return [0]  # Ultimate fallback
def _analyze_story_relevance(frame_path, ai_score, subtitle):
    """Analyze a frame for story relevance."""
    try:
        img = cv2.imread(frame_path)
        if img is None:
            return ai_score
        # 1. Face detection (dialogue scenes are important)
        gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
        face_cascade = cv2.CascadeClassifier(cv2.data.haarcascades + 'haarcascade_frontalface_default.xml')
        faces = face_cascade.detectMultiScale(gray, 1.1, 4)
        face_score = len(faces) * 0.2  # Bonus per detected face
        # 2. Motion/action detection
        motion_score = _detect_motion(img) * 0.15
        # 3. Scene complexity (more complex scenes may be more important)
        complexity_score = _analyze_scene_complexity(img) * 0.1
        # 4. Subtitle content analysis
        content_score = _analyze_subtitle_relevance(subtitle.content) * 0.15
        # Combine scores, capped at 1.0
        story_score = ai_score + face_score + motion_score + complexity_score + content_score
        return min(story_score, 1.0)
    except Exception:
        return ai_score  # Fall back to the raw AI score
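# Tiny runnable illustration of the 60/40 weighting with dummy inputs: the
# frame paths below do not exist, so _analyze_story_relevance falls back to
# the raw AI score and the combined score equals the AI score. Expected
# output: [1, 2, 0] (indices ranked by score).
import datetime as _dt
_demo_sub = srt.Subtitle(index=1, start=_dt.timedelta(0), end=_dt.timedelta(seconds=2), content="hello there")
print(_select_story_relevant_frames(["no_a.png", "no_b.png", "no_c.png"], [0.2, 0.9, 0.5], _demo_sub))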
def _detect_motion(img):
    """Detect motion/action in a frame, using edge density as a proxy."""
    try:
        gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
        edges = cv2.Canny(gray, 50, 150)
        edge_density = np.sum(edges > 0) / (edges.shape[0] * edges.shape[1])
        return min(edge_density * 10, 1.0)  # Normalize to [0, 1]
    except Exception:
        return 0.0
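# Sanity check for _detect_motion on synthetic frames: a flat gray image has
# no edges, so expect 0.0; a pixel checkerboard is edge-dense, so expect the
# score to saturate near 1.0.
_flat = np.full((100, 100, 3), 128, dtype=np.uint8)
_busy = np.zeros((100, 100, 3), dtype=np.uint8)
_busy[::2, ::2] = 255  # every other pixel white -> many Canny edges
print(_detect_motion(_flat), _detect_motion(_busy))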
def _analyze_scene_complexity(img):
    """Analyze scene complexity via luminance variance."""
    try:
        # Use the standard deviation of the L channel as a complexity indicator
        lab = cv2.cvtColor(img, cv2.COLOR_BGR2LAB)
        l_channel = lab[:, :, 0]
        complexity = np.std(l_channel) / 255.0
        return min(complexity * 2, 1.0)  # Normalize to [0, 1]
    except Exception:
        return 0.0
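# Similar sanity check for _analyze_scene_complexity: zero variance in the
# L channel for a flat frame, a much higher (seed-dependent) score for noise.
_noise = np.random.randint(0, 256, (64, 64, 3), dtype=np.uint8)
print(_analyze_scene_complexity(np.full((64, 64, 3), 128, dtype=np.uint8)))  # 0.0
print(_analyze_scene_complexity(_noise))  # well above 0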
def _analyze_subtitle_relevance(subtitle_text):
    """Analyze subtitle content for story relevance."""
    # Keywords that indicate important story moments
    important_keywords = [
        'hello', 'goodbye', 'thank', 'please', 'sorry', 'yes', 'no',
        'love', 'hate', 'help', 'danger', 'important', 'secret',
        'action', 'fight', 'run', 'stop', 'go', 'come', 'leave',
    ]
    text_lower = subtitle_text.lower()
    relevance_score = 0.0
    for keyword in important_keywords:
        if keyword in text_lower:
            relevance_score += 0.1
    return min(relevance_score, 1.0)  # Cap at 1.0
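# Quick check of the keyword scoring: each matched keyword adds 0.1, capped
# at 1.0. The sample line is made up for illustration; 'please', 'help', and
# 'danger' all match, so the score is 0.1 * 3.
print(round(_analyze_subtitle_relevance("Please help, we are in danger!"), 2))  # 0.3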
def black_bar_crop():
    ref_img_path = "frames/final/frame001.png"
    # Check that the reference image exists
    if not os.path.exists(ref_img_path):
        print(f"❌ Reference image not found: {ref_img_path}")
        return 0, 0, 0, 0
    # Detect the letterbox region once on the reference frame...
    x, y, w, h = get_black_bar_coordinates(ref_img_path)
    # ...then apply the same crop to every keyframe
    folder_dir = "frames/final"
    if not os.path.exists(folder_dir):
        print(f"❌ Frames directory not found: {folder_dir}")
        return x, y, w, h
    for image in os.listdir(folder_dir):
        img_path = os.path.join(folder_dir, image)
        image_data = cv2.imread(img_path)
        if image_data is not None:
            # Crop the image and overwrite it in place
            crop = image_data[y:y+h, x:x+w]
            cv2.imwrite(img_path, crop)
    return x, y, w, h
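# Cell 3 (usage sketch, commented out because it runs the full pipeline):
# generate_keyframes expects a sibling "test1.srt" subtitle file (hard-coded
# above) and writes up to 16 keyframes into frames/final/; black_bar_crop then
# removes letterboxing from those frames in place. The video path below is
# hypothetical.
# generate_keyframes("input_video.mp4")
# x, y, w, h = black_bar_crop()
# print(f"Black-bar crop region: x={x}, y={y}, w={w}, h={h}")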