|
|
|
|
|
import torch |
|
|
from torchvision import transforms |
|
|
from PIL import Image |
|
|
import numpy as np |
|
|
from backend.keyframes.model import DSN |
|
|
import torch.nn as nn |
|
|
import cv2 |
|
|
import time |
|
|
import os |
|
|
import srt |
|
|
from backend.keyframes.extract_frames import extract_frames |
|
|
from backend.utils import copy_and_rename_file, get_black_bar_coordinates, crop_image |
|
|
import signal |
|
|
import threading |
|
|
|
|
|
|
|
|
|
|
|
# Lazily-initialized, process-wide cache for the GoogLeNet feature extractor.
# Both values are populated on the first call to _get_features so the hub
# download/model construction happens only once per process.
_googlenet_model = None

# torchvision preprocessing pipeline matching the GoogLeNet input contract;
# built together with the model above.
_preprocess_pipeline = None
|
|
|
|
|
def _get_features(frames, gpu=True, batch_size=1):
    """Extract GoogLeNet pooled features for a list of frame image paths.

    Args:
        frames: iterable of image file paths.
        gpu: if True, run the model and inputs on CUDA.
        batch_size: accepted for interface compatibility; frames are
            currently processed one at a time.

    Returns:
        np.ndarray of dtype float32, one feature vector per frame.
    """
    global _googlenet_model, _preprocess_pipeline

    if _googlenet_model is None:
        print("🔄 Loading GoogLeNet model (this happens only once)...")
        backbone = torch.hub.load('pytorch/vision:v0.10.0', 'googlenet', weights='GoogLeNet_Weights.DEFAULT')

        # Drop the final FC classifier so the forward pass yields the
        # pooled feature vector instead of class logits.
        _googlenet_model = torch.nn.Sequential(*(list(backbone.children())[:-1]))
        _googlenet_model.eval()

        _preprocess_pipeline = transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ])

        if gpu:
            _googlenet_model.to('cuda')
        print("✅ GoogLeNet model loaded successfully")

    features = []

    for frame_path in frames:
        # Force RGB: extracted frames may be grayscale or RGBA, which would
        # break Normalize (it expects exactly 3 channels). The context
        # manager also closes the file handle promptly.
        with Image.open(frame_path) as input_image:
            input_tensor = _preprocess_pipeline(input_image.convert('RGB'))
        input_batch = input_tensor.unsqueeze(0)

        if gpu:
            input_batch = input_batch.to('cuda')

        # Inference only — no autograd graph needed.
        with torch.no_grad():
            output = _googlenet_model(input_batch)

        features.append(output.squeeze().cpu().numpy())

    features = np.array(features)

    return features.astype(np.float32)
|
|
|
|
|
|
|
|
# Cache of loaded DSN summarization models keyed by (mode, gpu) so each
# checkpoint is deserialized at most once per process (see _get_probs).
_dsn_models = {}
|
|
|
|
|
def _get_probs(features, gpu=True, mode=0):
    """Score frame features with a pretrained DSN highlight model.

    Args:
        features: np.ndarray of shape (num_frames, 1024) feature vectors.
        gpu: if True, load/run the model on CUDA (via DataParallel).
        mode: which pretrained checkpoint to use (1 -> model_1, else model_0).

    Returns:
        np.ndarray of per-frame highlight probabilities.
    """
    global _dsn_models

    cache_key = f"dsn_model_{mode}_{gpu}"

    if cache_key not in _dsn_models:
        print(f"🔄 Loading DSN model {mode} (this happens only once)...")

        if mode == 1:
            model_path = "backend/keyframes/pretrained_model/model_1.pth.tar"
        else:
            model_path = "backend/keyframes/pretrained_model/model_0.pth.tar"

        model = DSN(in_dim=1024, hid_dim=256, num_layers=1, cell="lstm")

        # On CPU-only runs the checkpoint must be remapped off CUDA devices.
        if gpu:
            checkpoint = torch.load(model_path)
        else:
            checkpoint = torch.load(model_path, map_location='cpu')

        model.load_state_dict(checkpoint)

        if gpu:
            model = nn.DataParallel(model).cuda()

        model.eval()
        _dsn_models[cache_key] = model
        print(f"✅ DSN model {mode} loaded successfully")

    model = _dsn_models[cache_key]
    # DSN expects a batch dimension: (1, num_frames, feature_dim).
    seq = torch.from_numpy(features).unsqueeze(0)
    if gpu:
        seq = seq.cuda()

    # Inference only — avoid building an autograd graph; .detach() replaces
    # the deprecated .data access.
    with torch.no_grad():
        probs = model(seq)
    probs = probs.detach().cpu().squeeze().numpy()
    return probs
|
|
|
|
|
|
|
|
|
|
|
def generate_keyframes(video):
    """Generate up to 16 story-relevant keyframes from *video* into frames/final.

    Reads subtitle timings from the hard-coded "test1.srt", extracts candidate
    frames per subtitle segment, scores them with the GoogLeNet + DSN pipeline,
    and copies the best ones to frames/final/frameNNN.png. Falls back to
    uniform sampling if nothing was selected, and to one-frame-per-subtitle if
    the 10-minute SIGALRM timeout fires (armed only on the main thread).
    """
    data = ""
    # NOTE(review): the subtitle path is hard-coded — presumably produced by
    # an earlier pipeline stage; confirm against the caller.
    with open("test1.srt") as f:
        data = f.read()

    # srt.parse returns a generator. Materialize it ONCE so it can be both
    # counted and iterated — the previous len(list(subs)) exhausted the
    # generator, leaving an empty list and skipping every segment.
    subs = list(srt.parse(data))
    torch.cuda.empty_cache()

    def timeout_handler(signum, frame):
        raise TimeoutError("Keyframe generation timed out")

    # SIGALRM handlers can only be installed from the main thread.
    use_alarm = threading.current_thread() is threading.main_thread()
    if use_alarm:
        signal.signal(signal.SIGALRM, timeout_handler)
        signal.alarm(600)

    final_dir = os.path.join("frames", "final")
    if not os.path.exists(final_dir):
        os.makedirs(final_dir)
        print(f"Created directory: {final_dir}")

    frame_counter = 1
    total_subs = len(subs)

    print(f"🎯 Processing {total_subs} subtitle segments...")

    try:
        for i, sub in enumerate(subs, 1):
            print(f"📝 Processing segment {i}/{total_subs}: {sub.content[:30]}...")
            frames = []
            if not os.path.exists(f"frames/sub{sub.index}"):
                os.makedirs(f"frames/sub{sub.index}")

            frames = extract_frames(video, os.path.join("frames", f"sub{sub.index}"),
                                    sub.start.total_seconds(), sub.end.total_seconds(), 10)

            if len(frames) > 0:
                # Score candidates on CPU so CUDA is not required here.
                features = _get_features(frames, gpu=False)
                highlight_scores = _get_probs(features, gpu=False)

                story_frames = _select_story_relevant_frames(frames, highlight_scores, sub)

                for j, frame_idx in enumerate(story_frames):
                    if frame_counter <= 16:
                        try:
                            copy_and_rename_file(frames[frame_idx], final_dir, f"frame{frame_counter:03}.png")
                            print(f"📖 Story frame {frame_counter}: {sub.content} (score: {highlight_scores[frame_idx]:.3f})")
                            frame_counter += 1
                        except Exception as e:
                            # Best-effort: skip frames that fail to copy, but say why.
                            print(f"⚠️ Could not copy frame {frame_idx}: {e}")
            else:
                print(f"⚠️ No frames extracted for subtitle {sub.index}")

        if frame_counter == 1:
            print("🚨 No story-relevant frames generated – falling back to uniform extraction…")
            try:
                # Sample ~16 evenly spaced frames straight from the video.
                video_cap = cv2.VideoCapture(video)
                total_frames = int(video_cap.get(cv2.CAP_PROP_FRAME_COUNT))
                step = max(total_frames // 16, 1)
                extracted = 0
                frame_idx = 0
                while extracted < 16 and video_cap.isOpened():
                    video_cap.set(cv2.CAP_PROP_POS_FRAMES, frame_idx)
                    ret, frame = video_cap.read()
                    if not ret:
                        break
                    out_path = os.path.join(final_dir, f"frame{frame_counter:03}.png")
                    cv2.imwrite(out_path, frame)
                    frame_counter += 1
                    extracted += 1
                    frame_idx += step
                video_cap.release()
                print(f"✅ Fallback extracted {extracted} uniform frames")
            except Exception as e:
                print(f"Fallback extraction failed: {e}")

        print(f"✅ Generated {frame_counter-1} story-relevant frames")

    except TimeoutError:
        print("⏰ Keyframe generation timed out, using fallback method...")

        # Cheap fallback: one frame from each of the first few subtitles.
        for i, sub in enumerate(subs[:4], 1):
            if frame_counter <= 16:
                try:
                    frames = extract_frames(video, os.path.join("frames", f"sub{sub.index}"),
                                            sub.start.total_seconds(), sub.end.total_seconds(), 1)
                    if frames:
                        copy_and_rename_file(frames[0], final_dir, f"frame{frame_counter:03}.png")
                        print(f"📖 Fallback frame {frame_counter}: {sub.content}")
                        frame_counter += 1
                except Exception as e:
                    print(f"⚠️ Timeout fallback failed for subtitle {sub.index}: {e}")

        print(f"✅ Generated {frame_counter-1} fallback frames")

    finally:
        # Cancel any pending alarm — only ever armed on the main thread.
        if use_alarm:
            signal.alarm(0)
|
|
|
|
|
def _select_story_relevant_frames(frames, highlight_scores, subtitle): |
|
|
"""Enhanced story-aware frame selection""" |
|
|
try: |
|
|
highlight_scores = list(highlight_scores) |
|
|
|
|
|
|
|
|
sorted_indices = [i[0] for i in sorted(enumerate(highlight_scores), key=lambda x: x[1], reverse=True)] |
|
|
|
|
|
|
|
|
story_scores = [] |
|
|
for i, frame_path in enumerate(frames): |
|
|
story_score = _analyze_story_relevance(frame_path, highlight_scores[i], subtitle) |
|
|
story_scores.append(story_score) |
|
|
|
|
|
|
|
|
combined_scores = [] |
|
|
for i in range(len(frames)): |
|
|
combined_score = (highlight_scores[i] * 0.6) + (story_scores[i] * 0.4) |
|
|
combined_scores.append(combined_score) |
|
|
|
|
|
|
|
|
sorted_combined = [i[0] for i in sorted(enumerate(combined_scores), key=lambda x: x[1], reverse=True)] |
|
|
|
|
|
|
|
|
num_frames_to_select = min(3, len(frames)) |
|
|
return sorted_combined[:num_frames_to_select] |
|
|
|
|
|
except Exception as e: |
|
|
print(f"Story selection failed: {e}") |
|
|
|
|
|
try: |
|
|
highlight_scores = list(highlight_scores) |
|
|
sorted_indices = [i[0] for i in sorted(enumerate(highlight_scores), key=lambda x: x[1], reverse=True)] |
|
|
return [sorted_indices[0]] if sorted_indices else [0] |
|
|
except: |
|
|
return [0] |
|
|
|
|
|
def _analyze_story_relevance(frame_path, ai_score, subtitle): |
|
|
"""Analyze frame for story relevance""" |
|
|
try: |
|
|
img = cv2.imread(frame_path) |
|
|
if img is None: |
|
|
return ai_score |
|
|
|
|
|
|
|
|
gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY) |
|
|
face_cascade = cv2.CascadeClassifier(cv2.data.haarcascades + 'haarcascade_frontalface_default.xml') |
|
|
faces = face_cascade.detectMultiScale(gray, 1.1, 4) |
|
|
face_score = len(faces) * 0.2 |
|
|
|
|
|
|
|
|
motion_score = _detect_motion(img) * 0.15 |
|
|
|
|
|
|
|
|
complexity_score = _analyze_scene_complexity(img) * 0.1 |
|
|
|
|
|
|
|
|
content_score = _analyze_subtitle_relevance(subtitle.content) * 0.15 |
|
|
|
|
|
|
|
|
story_score = ai_score + face_score + motion_score + complexity_score + content_score |
|
|
|
|
|
return min(story_score, 1.0) |
|
|
|
|
|
except Exception as e: |
|
|
return ai_score |
|
|
|
|
|
def _detect_motion(img): |
|
|
"""Detect motion/action in frame""" |
|
|
try: |
|
|
|
|
|
gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY) |
|
|
edges = cv2.Canny(gray, 50, 150) |
|
|
edge_density = np.sum(edges > 0) / (edges.shape[0] * edges.shape[1]) |
|
|
return min(edge_density * 10, 1.0) |
|
|
except: |
|
|
return 0.0 |
|
|
|
|
|
def _analyze_scene_complexity(img): |
|
|
"""Analyze scene complexity""" |
|
|
try: |
|
|
|
|
|
lab = cv2.cvtColor(img, cv2.COLOR_BGR2LAB) |
|
|
l_channel = lab[:,:,0] |
|
|
complexity = np.std(l_channel) / 255.0 |
|
|
return min(complexity * 2, 1.0) |
|
|
except: |
|
|
return 0.0 |
|
|
|
|
|
def _analyze_subtitle_relevance(subtitle_text): |
|
|
"""Analyze subtitle content for story relevance""" |
|
|
|
|
|
important_keywords = [ |
|
|
'hello', 'goodbye', 'thank', 'please', 'sorry', 'yes', 'no', |
|
|
'love', 'hate', 'help', 'danger', 'important', 'secret', |
|
|
'action', 'fight', 'run', 'stop', 'go', 'come', 'leave' |
|
|
] |
|
|
|
|
|
text_lower = subtitle_text.lower() |
|
|
relevance_score = 0.0 |
|
|
|
|
|
for keyword in important_keywords: |
|
|
if keyword in text_lower: |
|
|
relevance_score += 0.1 |
|
|
|
|
|
return min(relevance_score, 1.0) |
|
|
|
|
|
|
|
|
def black_bar_crop():
    """Crop detected black bars off every image in frames/final, in place.

    The crop rectangle is measured once on the first keyframe
    (frame001.png) and then applied to all images in the directory.

    Returns:
        (x, y, w, h) crop rectangle; (0, 0, 0, 0) if the reference image
        is missing.
    """
    ref_img_path = "frames/final/frame001.png"

    if not os.path.exists(ref_img_path):
        print(f"❌ Reference image not found: {ref_img_path}")
        return 0, 0, 0, 0

    x, y, w, h = get_black_bar_coordinates(ref_img_path)

    folder_dir = "frames/final"
    if not os.path.exists(folder_dir):
        print(f"❌ Frames directory not found: {folder_dir}")
        return x, y, w, h

    # Degenerate rectangle would crop every frame to an empty image and
    # make cv2.imwrite fail — leave the frames untouched instead.
    if w <= 0 or h <= 0:
        return x, y, w, h

    for image in os.listdir(folder_dir):
        # Build the path from folder_dir for consistency with the checks above.
        img_path = os.path.join(folder_dir, image)
        if os.path.exists(img_path):
            image_data = cv2.imread(img_path)
            if image_data is not None:
                crop = image_data[y:y+h, x:x+w]
                # Overwrite the frame with its cropped version.
                cv2.imwrite(img_path, crop)

    return x, y, w, h