# backend/keyframes/keyframes.py
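"""Keyframe extraction for the comic-generation pipeline.

For each subtitle segment, candidate frames are extracted from the video,
scored with GoogLeNet features fed to a pretrained DSN highlight model,
re-ranked with simple story-relevance heuristics (faces, motion, scene
complexity, subtitle keywords), and the best frames are copied to
frames/final/. A uniform-sampling fallback covers the case where no
story-relevant frames are produced.
"""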
# Cell 1
import torch
from torchvision import transforms
from PIL import Image
import numpy as np
from backend.keyframes.model import DSN
import torch.nn as nn
import cv2
import time
import os
import srt
from backend.keyframes.extract_frames import extract_frames
from backend.utils import copy_and_rename_file, get_black_bar_coordinates, crop_image
import signal
import threading  # signal-based timeouts can only be armed on the main thread
# Cell 2
# Global model cache to avoid reloading
_googlenet_model = None
_preprocess_pipeline = None
def _get_features(frames, gpu=True, batch_size=1):
    """Extract a 1024-d GoogLeNet feature vector for each frame path in `frames`.

    batch_size is currently unused: frames are processed one at a time.
    """
    global _googlenet_model, _preprocess_pipeline

    # Load the pre-trained GoogLeNet model only once
    if _googlenet_model is None:
        print("🔄 Loading GoogLeNet model (this happens only once)...")
        # The v0.10.0 hub entry point takes `pretrained`, not `weights`
        _googlenet_model = torch.hub.load('pytorch/vision:v0.10.0', 'googlenet', pretrained=True)
        # Remove the classification layer (last layer) to obtain features
        _googlenet_model = torch.nn.Sequential(*(list(_googlenet_model.children())[:-1]))
        _googlenet_model.eval()

        # Standard ImageNet preprocessing for GoogLeNet
        _preprocess_pipeline = transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ])

        # Move to GPU if requested
        if gpu:
            _googlenet_model.to('cuda')
        print("✅ GoogLeNet model loaded successfully")

    # Extract a feature vector for every frame
    features = []
    for frame_path in frames:
        # Load and preprocess the frame (convert to RGB in case of grayscale/RGBA images)
        input_image = Image.open(frame_path).convert('RGB')
        input_tensor = _preprocess_pipeline(input_image)
        input_batch = input_tensor.unsqueeze(0)  # Add batch dimension

        # Move the input to the GPU if requested
        if gpu:
            input_batch = input_batch.to('cuda')

        # Perform feature extraction
        with torch.no_grad():
            output = _googlenet_model(input_batch)

        features.append(output.squeeze().cpu().numpy())

    # Convert the list of feature vectors to a (num_frames, 1024) NumPy array
    features = np.array(features)
    return features.astype(np.float32)
# Global DSN model cache
_dsn_models = {}
def _get_probs(features, gpu=True, mode=0):
    """Score each frame's importance with a pretrained DSN model."""
    global _dsn_models

    # Cache key distinguishes the model variant and device placement
    cache_key = f"dsn_model_{mode}_{gpu}"

    # Load the model only if it is not already cached
    if cache_key not in _dsn_models:
        print(f"🔄 Loading DSN model {mode} (this happens only once)...")
        if mode == 1:
            model_path = "backend/keyframes/pretrained_model/model_1.pth.tar"
        else:
            model_path = "backend/keyframes/pretrained_model/model_0.pth.tar"
        model = DSN(in_dim=1024, hid_dim=256, num_layers=1, cell="lstm")
        if gpu:
            checkpoint = torch.load(model_path)
        else:
            checkpoint = torch.load(model_path, map_location='cpu')
        model.load_state_dict(checkpoint)
        if gpu:
            model = nn.DataParallel(model).cuda()
        model.eval()
        _dsn_models[cache_key] = model
        print(f"✅ DSN model {mode} loaded successfully")

    model = _dsn_models[cache_key]
    seq = torch.from_numpy(features).unsqueeze(0)  # (1, num_frames, 1024)
    if gpu:
        seq = seq.cuda()
    with torch.no_grad():
        probs = model(seq)
    return probs.cpu().squeeze().numpy()
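# The two helpers above chain together: GoogLeNet turns frame images into
# 1024-d feature vectors, and the DSN turns that feature sequence into one
# importance score per frame. For example (hypothetical paths):
#   feats = _get_features(["frames/sub1/0001.png", "frames/sub1/0002.png"], gpu=False)
#   scores = _get_probs(feats, gpu=False)  # one per-frame importance score each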
def generate_keyframes(video):
    with open("test1.srt") as f:
        data = f.read()
    # Materialize the parsed subtitles up front; srt.parse returns a generator,
    # so taking len(list(subs)) first would exhaust it and leave nothing to iterate
    subs = list(srt.parse(data))
    total_subs = len(subs)
    torch.cuda.empty_cache()

    # Add timeout protection
    def timeout_handler(signum, frame):
        raise TimeoutError("Keyframe generation timed out")

    # Set a 10-minute timeout, but only when running on the main thread
    # (signal handlers cannot be installed from worker threads)
    if threading.current_thread() is threading.main_thread():
        signal.signal(signal.SIGALRM, timeout_handler)
        signal.alarm(600)

    # Create the output directory if it doesn't exist
    final_dir = os.path.join("frames", "final")
    if not os.path.exists(final_dir):
        os.makedirs(final_dir)
        print(f"Created directory: {final_dir}")

    frame_counter = 1
    print(f"🎯 Processing {total_subs} subtitle segments...")
    try:
        # Story-aware keyframe extraction
        for i, sub in enumerate(subs, 1):
            print(f"📝 Processing segment {i}/{total_subs}: {sub.content[:30]}...")
            if not os.path.exists(f"frames/sub{sub.index}"):
                os.makedirs(f"frames/sub{sub.index}")

            # Extract several candidate frames per segment for better story selection
            frames = extract_frames(video, os.path.join("frames", f"sub{sub.index}"),
                                    sub.start.total_seconds(), sub.end.total_seconds(), 10)

            if frames:
                # Get AI highlight scores
                features = _get_features(frames, gpu=False)
                highlight_scores = _get_probs(features, gpu=False)

                # Story-aware selection
                story_frames = _select_story_relevant_frames(frames, highlight_scores, sub)

                # Save the best story frames, limited to 16 frames total
                for frame_idx in story_frames:
                    if frame_counter <= 16:
                        try:
                            copy_and_rename_file(frames[frame_idx], final_dir, f"frame{frame_counter:03}.png")
                            print(f"📖 Story frame {frame_counter}: {sub.content} (score: {highlight_scores[frame_idx]:.3f})")
                            frame_counter += 1
                        except Exception as e:
                            print(f"⚠️ Failed to save frame: {e}")
            else:
                print(f"⚠️ No frames extracted for subtitle {sub.index}")

        # If no frames were successfully generated, fall back to uniform extraction
        if frame_counter == 1:
            print("🚨 No story-relevant frames generated – falling back to uniform extraction…")
            try:
                # Extract 16 evenly spaced frames across the entire video
                video_cap = cv2.VideoCapture(video)
                total_frames = int(video_cap.get(cv2.CAP_PROP_FRAME_COUNT))
                step = max(total_frames // 16, 1)
                extracted = 0
                frame_idx = 0
                while extracted < 16 and video_cap.isOpened():
                    video_cap.set(cv2.CAP_PROP_POS_FRAMES, frame_idx)
                    ret, frame = video_cap.read()
                    if not ret:
                        break
                    out_path = os.path.join(final_dir, f"frame{frame_counter:03}.png")
                    cv2.imwrite(out_path, frame)
                    frame_counter += 1
                    extracted += 1
                    frame_idx += step
                video_cap.release()
                print(f"✅ Fallback extracted {extracted} uniform frames")
            except Exception as e:
                print(f"Fallback extraction failed: {e}")

        print(f"✅ Generated {frame_counter-1} story-relevant frames")
    except TimeoutError:
        print("⏰ Keyframe generation timed out, using fallback method...")
        # Fallback: take one frame from each of the first four subtitle segments
        for sub in subs[:4]:
            if frame_counter <= 16:
                try:
                    # Simple frame extraction without AI scoring
                    frames = extract_frames(video, os.path.join("frames", f"sub{sub.index}"),
                                            sub.start.total_seconds(), sub.end.total_seconds(), 1)
                    if frames:
                        copy_and_rename_file(frames[0], final_dir, f"frame{frame_counter:03}.png")
                        print(f"📖 Fallback frame {frame_counter}: {sub.content}")
                        frame_counter += 1
                except Exception as e:
                    print(f"⚠️ Fallback frame failed: {e}")
        print(f"✅ Generated {frame_counter-1} fallback frames")
    finally:
        # Cancel the timeout (it was only armed on the main thread)
        if threading.current_thread() is threading.main_thread():
            signal.alarm(0)
def _select_story_relevant_frames(frames, highlight_scores, subtitle):
    """Enhanced story-aware frame selection."""
    try:
        highlight_scores = list(highlight_scores)

        # 1. Analyze each frame for story relevance
        story_scores = []
        for i, frame_path in enumerate(frames):
            story_scores.append(_analyze_story_relevance(frame_path, highlight_scores[i], subtitle))

        # 2. Combine AI scores with story relevance (60% AI, 40% story)
        combined_scores = [(highlight_scores[i] * 0.6) + (story_scores[i] * 0.4)
                           for i in range(len(frames))]

        # 3. Select the top frames by combined score
        sorted_combined = [i for i, _ in sorted(enumerate(combined_scores), key=lambda x: x[1], reverse=True)]

        # Return the top 2-3 frames per segment for better story coverage
        num_frames_to_select = min(3, len(frames))
        return sorted_combined[:num_frames_to_select]
    except Exception as e:
        print(f"Story selection failed: {e}")
        # Fallback: take the single highest AI-scored frame
        try:
            highlight_scores = list(highlight_scores)
            sorted_indices = [i for i, _ in sorted(enumerate(highlight_scores), key=lambda x: x[1], reverse=True)]
            return [sorted_indices[0]] if sorted_indices else [0]
        except Exception:
            return [0]  # Ultimate fallback
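# Example of the combined scoring: a frame with AI score 0.8 and story score
# 0.5 gets 0.8 * 0.6 + 0.5 * 0.4 = 0.68, while one with AI 0.9 but story 0.1
# gets 0.9 * 0.6 + 0.1 * 0.4 = 0.58 – so a moderately highlight-worthy frame
# that matches the story can outrank a flashier but irrelevant one.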
def _analyze_story_relevance(frame_path, ai_score, subtitle):
    """Boost a frame's AI score with simple visual and subtitle heuristics."""
    try:
        img = cv2.imread(frame_path)
        if img is None:
            return ai_score

        # 1. Face detection (dialogue scenes are important)
        gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
        face_cascade = cv2.CascadeClassifier(cv2.data.haarcascades + 'haarcascade_frontalface_default.xml')
        faces = face_cascade.detectMultiScale(gray, 1.1, 4)
        face_score = len(faces) * 0.2  # Bonus per detected face

        # 2. Motion/action detection
        motion_score = _detect_motion(img) * 0.15

        # 3. Scene complexity (more complex scenes may be more important)
        complexity_score = _analyze_scene_complexity(img) * 0.1

        # 4. Subtitle content analysis
        content_score = _analyze_subtitle_relevance(subtitle.content) * 0.15

        # Combine the scores, capped at 1.0
        story_score = ai_score + face_score + motion_score + complexity_score + content_score
        return min(story_score, 1.0)
    except Exception:
        return ai_score  # Fall back to the raw AI score
def _detect_motion(img):
    """Estimate motion/action using edge density as a crude proxy."""
    try:
        gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
        edges = cv2.Canny(gray, 50, 150)
        edge_density = np.sum(edges > 0) / (edges.shape[0] * edges.shape[1])
        return min(edge_density * 10, 1.0)  # Normalize to 0-1
    except Exception:
        return 0.0
def _analyze_scene_complexity(img):
    """Analyze scene complexity via lightness variance."""
    try:
        # Use the standard deviation of the LAB lightness channel as a complexity indicator
        lab = cv2.cvtColor(img, cv2.COLOR_BGR2LAB)
        l_channel = lab[:, :, 0]
        complexity = np.std(l_channel) / 255.0
        return min(complexity * 2, 1.0)  # Normalize to 0-1
    except Exception:
        return 0.0
def _analyze_subtitle_relevance(subtitle_text):
    """Score subtitle content for story relevance: 0.1 per matched keyword, capped at 1.0."""
    # Keywords that indicate important story moments
    important_keywords = [
        'hello', 'goodbye', 'thank', 'please', 'sorry', 'yes', 'no',
        'love', 'hate', 'help', 'danger', 'important', 'secret',
        'action', 'fight', 'run', 'stop', 'go', 'come', 'leave'
    ]
    text_lower = subtitle_text.lower()
    relevance_score = 0.0
    for keyword in important_keywords:
        if keyword in text_lower:
            relevance_score += 0.1
    return min(relevance_score, 1.0)
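# Example: "Please help me, it's dangerous!" matches 'please', 'help', and
# 'danger' (inside "dangerous"), scoring 0.3. Note the matching is a loose
# substring check, so 'no' also fires inside words like "know".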
def black_bar_crop():
    ref_img_path = "frames/final/frame001.png"

    # Check that the reference image exists
    if not os.path.exists(ref_img_path):
        print(f"❌ Reference image not found: {ref_img_path}")
        return 0, 0, 0, 0

    # Detect the black-bar crop box once, on the reference frame
    x, y, w, h = get_black_bar_coordinates(ref_img_path)

    # Loop through each keyframe
    folder_dir = "frames/final"
    if not os.path.exists(folder_dir):
        print(f"❌ Frames directory not found: {folder_dir}")
        return x, y, w, h

    for image in os.listdir(folder_dir):
        img_path = os.path.join(folder_dir, image)
        image_data = cv2.imread(img_path)
        if image_data is not None:
            # Crop to the detected content region and save in place
            crop = image_data[y:y+h, x:x+w]
            cv2.imwrite(img_path, crop)

    return x, y, w, h
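# Minimal usage sketch (hypothetical paths): generate_keyframes reads its
# subtitles from the hard-coded "test1.srt" and writes frames/final/frame001.png
# onward, which black_bar_crop then letterbox-crops in place. "video.mp4" is a
# placeholder name.
if __name__ == "__main__":
    generate_keyframes("video.mp4")
    black_bar_crop()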