# Captured from a Hugging Face Spaces page; Space status at capture time: Runtime error
import os
import re
import shutil
import subprocess
import tempfile
import time
import traceback
import uuid

import cv2
import gradio as gr
from datasets import load_dataset
from llama_cpp import Llama
from qdrant_client import QdrantClient
from sentence_transformers import SentenceTransformer
# Configuration
QDRANT_COLLECTION_NAME = "video_frames"
# Length (seconds) of the clip cut around a matched timestamp. This agrees
# with the value the file assigns again further down, which is the binding
# that is in effect when the request handlers run.
VIDEO_SEGMENT_DURATION = 40
# Qdrant endpoint; overridable through the QDRANT_URL environment variable,
# defaulting to the cluster this app was written against.
QDRANT_URL = os.environ.get(
    "QDRANT_URL",
    "https://2c18d413-cbb5-441c-b060-4c8c2302dcde.us-east4-0.gcp.cloud.qdrant.io:6333/",
)

# Load Qdrant key
QDRANT_API_KEY = os.environ.get("QDRANT_API_KEY")
if not QDRANT_API_KEY:
    print("Error: QDRANT_API_KEY environment variable not found.")
    print("Please add your Qdrant API key as a secret named 'QDRANT_API_KEY' in your Hugging Face Space settings.")
    raise ValueError("QDRANT_API_KEY environment variable not set.")

# Each startup step below logs its failure and re-raises: the app cannot
# serve requests without all four resources, so it fails fast at import.
print("Initializing LLM...")
try:
    llm = Llama.from_pretrained(
        repo_id="m1tch/gemma-finetune-ai_class_gguf",
        filename="gemma-3_ai_class.Q8_0.gguf",
        n_gpu_layers=-1,
        n_ctx=2048,
        verbose=False,
    )
    print("LLM initialized successfully.")
except Exception as e:
    print(f"Error initializing LLM: {e}")
    raise

print("Connecting to Qdrant...")
try:
    qdrant_client = QdrantClient(
        url=QDRANT_URL,
        api_key=QDRANT_API_KEY,
        timeout=60,
    )
    # Round-trip once so bad credentials or an unreachable host surface now.
    qdrant_client.get_collections()
    print("Qdrant connection successful.")
except Exception as e:
    print(f"Error connecting to Qdrant: {e}")
    raise

print("Loading dataset stream...")
try:
    # Load video dataset
    dataset = load_dataset("aegean-ai/ai-lectures-spring-24", split="train", streaming=True)
    # Pulling the first sample doubles as a check that the stream is readable.
    print(f"Dataset loaded. First item example: {next(iter(dataset))['__key__']}")
except Exception as e:
    print(f"Error loading dataset: {e}")
    raise

try:
    embedding_model = SentenceTransformer('all-MiniLM-L6-v2')
    print("Sentence Transformer model loaded.")
except Exception as e:
    print(f"Error loading Sentence Transformer model: {e}")
    raise
def rag_query(client, collection_name, query_text, top_k=5, filter_condition=None):
    """Embed ``query_text`` and search the vector collection for similar entries.

    Uses the module-level ``embedding_model`` to encode the query.

    Args:
        client: Qdrant client exposing ``search``.
        collection_name: Name of the collection to search.
        query_text: Natural-language query to embed.
        top_k: Maximum number of hits to return.
        filter_condition: Optional Qdrant filter restricting the search.

    Returns:
        On success, a dict with ``query``, ``results`` (one dict per hit with
        ``rank``, ``score``, ``video_id``, ``timestamp``, ``subtitle`` and
        ``frame_number``) and ``avg_score`` (0 when there are no hits).
        On any failure, a dict with ``error``, ``query`` and an empty
        ``results`` list; exceptions are never propagated to the caller.
    """
    try:
        query_vector = embedding_model.encode(query_text).tolist()
        search_params = {
            "collection_name": collection_name,
            "query_vector": query_vector,
            "limit": top_k,
            "with_payload": True,
            "with_vectors": False,
        }
        if filter_condition:
            # QdrantClient.search names this argument ``query_filter``;
            # passing ``filter`` is rejected as an unknown argument.
            search_params["query_filter"] = filter_condition
        search_results = client.search(**search_params)

        formatted_results = []
        for rank, result in enumerate(search_results, start=1):
            payload = result.payload or {}  # tolerate hits without a payload
            formatted_results.append({
                "rank": rank,
                "score": result.score,
                "video_id": payload.get("video_id"),
                "timestamp": payload.get("timestamp"),
                "subtitle": payload.get("subtitle"),
                "frame_number": payload.get("frame_number"),
            })

        if search_results:
            avg_score = sum(r.score for r in search_results) / len(search_results)
        else:
            avg_score = 0
        return {
            "query": query_text,
            "results": formatted_results,
            "avg_score": avg_score,
        }
    except Exception as e:
        print(f"Error during RAG query: {e}")
        return {"error": str(e), "query": query_text, "results": []}
def _is_nonempty_file(path):
    """Return True when ``path`` exists and contains at least one byte."""
    return os.path.exists(path) and os.path.getsize(path) > 0


def _remove_quietly(path):
    """Delete ``path`` if it exists; print a warning instead of raising."""
    if not os.path.exists(path):
        return
    try:
        os.remove(path)
    except OSError as e:
        print(f"Warning: Could not remove temporary file {path}: {e}")


def _save_video_from_dataset(dataset, target_key, dest_path, max_samples=300):
    """Scan ``dataset`` for ``target_key`` and write its 'mp4' bytes to ``dest_path``.

    At most ``max_samples`` read attempts are made. Returns True when the
    video was found and saved, False otherwise.
    """
    dataset_iterator = iter(dataset)
    attempts = 0
    while attempts < max_samples:
        attempts += 1
        try:
            sample = next(dataset_iterator)
        except StopIteration:
            print("Reached end of dataset stream without finding the video.")
            break
        except Exception as e:
            # Stream hiccup: pause briefly, then try the next read.
            print(f"Error iterating dataset: {e}")
            time.sleep(1)
            continue
        if '__key__' in sample and sample['__key__'] == target_key:
            print(f"Found video key {target_key}. Saving to {dest_path}...")
            with open(dest_path, 'wb') as f:
                f.write(sample['mp4'])
            print(f"Video saved successfully ({os.path.getsize(dest_path)} bytes).")
            return True
    print(f"Dataset key {target_key} not found after {attempts} attempts.")
    return False


def _probe_segment(video_path, start_time, duration):
    """Read video properties with OpenCV and clamp the requested frame range.

    Returns ``(fps, (width, height), start_frame, end_frame)``, or None when
    the file cannot be opened or the requested range lies outside the video.
    """
    cap = cv2.VideoCapture(video_path)
    try:
        if not cap.isOpened():
            print(f"Error opening video file with OpenCV: {video_path}")
            return None
        fps = cap.get(cv2.CAP_PROP_FPS)
        if fps <= 0:
            print(f"Warning: Invalid FPS ({fps}) detected for {video_path}. Assuming 30 FPS.")
            fps = 30
        width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
        total_vid_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
    finally:
        cap.release()

    vid_duration = total_vid_frames / fps
    print(f"Video properties: {width}x{height} @ {fps:.2f}fps, Total Duration: {vid_duration:.2f}s")

    # Clamp frame numbers to the valid range.
    start_frame = max(0, int(start_time * fps))
    end_frame = min(total_vid_frames, int((start_time + duration) * fps))
    if start_frame >= total_vid_frames or start_frame >= end_frame:
        print(f"Calculated start frame ({start_frame}) is beyond video length ({total_vid_frames}) or segment is invalid.")
        return None
    return fps, (width, height), start_frame, end_frame


def _write_segment_ffmpeg(video_path, output_path, start_time, duration, timeout=120):
    """Cut and re-encode the segment with FFmpeg (H.264/AAC). Returns True on success."""
    cmd = [
        'ffmpeg',
        '-ss', str(start_time),  # Start time
        '-i', video_path,  # Input file (original downloaded)
        '-t', str(duration),  # Duration of the segment
        '-c:v', 'libx264',
        '-profile:v', 'baseline',
        '-level', '3.0',
        '-preset', 'fast',
        '-pix_fmt', 'yuv420p',
        '-movflags', '+faststart',
        '-c:a', 'aac',
        '-b:a', '128k',
        '-y',
        output_path,
    ]
    print(f"Running FFmpeg command: {' '.join(cmd)}")
    try:
        result = subprocess.run(cmd, capture_output=True, text=True, timeout=timeout)
    except subprocess.TimeoutExpired:
        print("FFmpeg command timed out.")
        return False
    except FileNotFoundError:
        print("Error: ffmpeg command not found. Make sure FFmpeg is installed and in your system's PATH.")
        return False
    except Exception as e:
        print(f"An unexpected error occurred during FFmpeg processing: {e}")
        return False

    if result.returncode == 0 and _is_nonempty_file(output_path):
        print(f"FFmpeg processing successful. Output: {output_path}")
        return True
    print(f"FFmpeg error (Return Code: {result.returncode}):")
    print(f"FFmpeg stdout:\n{result.stdout}")
    print(f"FFmpeg stderr:\n{result.stderr}")
    return False


def _write_segment_opencv(video_path, output_path, fps, frame_size, start_frame, end_frame):
    """Copy frames ``[start_frame, end_frame)`` into an mp4v file (video only).

    Returns True when a non-empty output file was produced.
    """
    cap = cv2.VideoCapture(video_path)
    writer = None
    try:
        if not cap.isOpened():
            print(f"Error opening video file with OpenCV: {video_path}")
            return False
        cap.set(cv2.CAP_PROP_POS_FRAMES, start_frame)
        frames_to_write = end_frame - start_frame
        print(f"Extracting frames from {start_frame} to {end_frame} ({frames_to_write} frames)")
        # mp4v is often more compatible than avc1 with base OpenCV
        fourcc = cv2.VideoWriter_fourcc(*'mp4v')
        writer = cv2.VideoWriter(output_path, fourcc, fps, frame_size)
        if not writer.isOpened():
            print("Error opening OpenCV VideoWriter with mp4v.")
            return False
        frames_written = 0
        while frames_written < frames_to_write:
            ret, frame = cap.read()
            if not ret:
                print("Warning: Ran out of frames before reaching target end frame.")
                break
            writer.write(frame)
            frames_written += 1
        print(f"OpenCV finished writing {frames_written} frames to {output_path}")
    finally:
        if writer is not None:
            writer.release()
        cap.release()
    return _is_nonempty_file(output_path)


def extract_video_segment(video_id, start_time, duration, dataset):
    """Cut one clip out of a dataset video and return the clip's file path.

    The video is looked up in ``dataset`` under the key
    ``videos/<video_id>/<video_id>``, saved to a fresh temporary directory,
    validated with OpenCV and cut with FFmpeg. If FFmpeg is unavailable or
    fails, the clip is written with OpenCV instead.

    Args:
        video_id: Identifier of the video; converted to ``str``.
        start_time: Clip start in seconds; converted to ``float``.
        duration: Clip length in seconds; converted to ``float``.
        dataset: Iterable of samples carrying ``__key__`` and ``mp4`` entries.

    Returns:
        Path of the clip (a single path suitable for Gradio), or None when
        the video is not found, the range is invalid, or both encoders fail.
        On success the clip stays in its temporary directory for the caller;
        on failure the whole temporary directory is removed.
    """
    target_id = str(video_id)
    target_key = f"videos/{target_id}/{target_id}"
    start_time = float(start_time)
    duration = float(duration)

    unique_id = str(uuid.uuid4())
    temp_dir = os.path.join(tempfile.gettempdir(), f"gradio_video_{unique_id}")
    os.makedirs(temp_dir, exist_ok=True)
    temp_video_path = os.path.join(temp_dir, f"{target_id}_full_{unique_id}.mp4")
    output_path_opencv = os.path.join(temp_dir, f"output_opencv_{unique_id}.mp4")
    output_path_ffmpeg = os.path.join(temp_dir, f"output_ffmpeg_{unique_id}.mp4")

    print(f"Attempting to extract segment for video_id={target_id}, start={start_time}, duration={duration}")
    print(f"Looking for dataset key: {target_key}")
    print(f"Temporary directory: {temp_dir}")

    final_output_path = None
    try:
        if not _save_video_from_dataset(dataset, target_key, temp_video_path):
            print(f"Could not find video with ID {target_id} (key: {target_key}) in the dataset stream.")
            return None

        # Process the saved video
        if not _is_nonempty_file(temp_video_path):
            print(f"Temporary video file {temp_video_path} is missing or empty.")
            return None

        probe = _probe_segment(temp_video_path, start_time, duration)
        if probe is None:
            return None
        fps, frame_size, start_frame, end_frame = probe

        if _write_segment_ffmpeg(temp_video_path, output_path_ffmpeg, start_time, duration):
            final_output_path = output_path_ffmpeg
        else:
            # The OpenCV encode only runs when FFmpeg could not deliver.
            print("Falling back to OpenCV output.")
            _remove_quietly(output_path_ffmpeg)
            if _write_segment_opencv(temp_video_path, output_path_opencv, fps,
                                     frame_size, start_frame, end_frame):
                final_output_path = output_path_opencv
            else:
                print("OpenCV output is also invalid or empty.")

        print(f"Returning video segment path: {final_output_path}")
        return final_output_path
    except Exception as e:
        print(f"Error processing video segment for {video_id}: {e}")
        traceback.print_exc()
        final_output_path = None
        return None
    finally:
        # The full download is never needed after this call; on failure
        # nothing in the temp directory is, so drop the directory as well.
        _remove_quietly(temp_video_path)
        if final_output_path is None:
            shutil.rmtree(temp_dir, ignore_errors=True)
        else:
            print(f"Cleaned up temporary full video: {temp_video_path}")
# NOTE: both names are also assigned near the top of the file. These
# assignments run later, so they are the values the request handlers see.
QDRANT_COLLECTION_NAME = "video_frames"
VIDEO_SEGMENT_DURATION = 40  # Extract 40 seconds around the timestamp
def _extract_reasoning(text):
    """Return the text of the ``{Reasoning: ...}`` field in ``text``, or None."""
    # Match on the original string: offsets taken from ``text.lower()`` are
    # wrong whenever lowercasing changes the string's length.
    marker = re.search(r"\{reasoning:", text, re.IGNORECASE)
    if marker is None:
        print("Warning: Marker '{reasoning:' not found in text.")
        return None

    body_start = marker.end()
    body_end = text.find('}', body_start)
    truncated = body_end == -1
    if truncated:
        # Reasoning is the last field, so a missing brace usually means the
        # generation was cut off; keep what was produced.
        print("Warning: Found '{reasoning:' marker but no closing '}' found afterwards.")
        value = text[body_start:].strip()
        if value.startswith('['):
            value = value[1:]
    else:
        value = text[body_start:body_end].strip()
        if value.startswith('[') and value.endswith(']'):
            value = value[1:-1]

    value = value.strip('\'"“”').strip()
    if truncated and not value:
        return None
    return value


def parse_llm_output(text):
    """Parse the LLM's structured reply into its individual fields.

    ``video_id`` and ``timestamp`` are read with regular expressions from
    ``{Best Result: [...]}`` and ``{Timestamp: [...]}``; ``reasoning`` is
    read from ``{Reasoning: [...]}`` by locating the marker and the next
    closing brace. Matching is case-insensitive, and surrounding brackets
    and quotes are stripped from each value.

    Args:
        text: Raw LLM reply. ``None`` or a non-string is treated as empty.

    Returns:
        Dict with keys ``video_id``, ``timestamp`` and ``reasoning``; each
        value is a string, or None when the field could not be found. The
        timestamp is returned as text; a warning is printed when it is not
        a valid number, and the caller is expected to validate it.
    """
    if not isinstance(text, str):
        print(f"Warning: LLM output is not text ({type(text).__name__}); treating it as empty.")
        text = ""

    data = {}
    # Parse video_id and timestamp with regex
    simple_patterns = {
        'video_id': r"\{Best Result:\s*\[?([^\]\}]+)\]?\s*\}",
        'timestamp': r"\{Timestamp:\s*\[?([^\]\}]+)\]?\s*\}",
    }
    for key, pattern in simple_patterns.items():
        match = re.search(pattern, text, re.IGNORECASE)
        if match:
            data[key] = match.group(1).strip().strip('\'"“”')
        else:
            print(f"Warning: Could not parse '{key}' using regex pattern: {pattern}")
            data[key] = None

    # Parse reasoning
    data['reasoning'] = _extract_reasoning(text)

    if data.get('timestamp'):
        try:
            float(data['timestamp'])
        except ValueError:
            print(f"Warning: Parsed timestamp '{data['timestamp']}' is not a valid number.")

    print(f"Parsed LLM output (Using String Manipulation for Reasoning): {data}")
    return data
def process_query_and_get_video(query_text):
    """Answer a question with an LLM-selected lecture clip.

    Orchestrates RAG, LLM query, parsing, and video extraction: retrieves
    candidate subtitle segments, asks the LLM to pick the best one, parses
    its reply and cuts the matching clip out of the dataset video.

    Args:
        query_text: The user's question.

    Returns:
        A ``(message, video_path)`` tuple for the two Gradio outputs.
        ``message`` is the LLM's reasoning on success or an explanation of
        what went wrong; ``video_path`` is the clip path, or None when no
        clip could be produced.
    """
    # Nothing to search for: answer immediately instead of running the
    # embedder and the LLM on an empty question.
    if not isinstance(query_text, str) or not query_text.strip():
        print("Empty query received; nothing to process.")
        return "Please enter a question.", None

    print(f"\n--- Processing query: '{query_text}' ---")

    # 1. RAG Query
    print("Step 1: Performing RAG query...")
    rag_results = rag_query(qdrant_client, QDRANT_COLLECTION_NAME, query_text)
    if "error" in rag_results or not rag_results.get("results"):
        error_msg = rag_results.get('error', 'No relevant segments found by RAG.')
        print(f"RAG Error/No Results: {error_msg}")
        return f"Error during RAG search: {error_msg}", None
    print(f"RAG query successful. Found {len(rag_results['results'])} results.")

    # 2. Format LLM Prompt
    print("Step 2: Formatting prompt for LLM...")
    prompt = f"""You are tasked with selecting the most relevant information from a set of video subtitle segments to answer a query.
QUERY (also seen below): "{query_text}"
For each result provided, evaluate how well it directly addresses the definition or explanation related to the query. Pay attention to:
1. Clarity of explanation
2. Relevance to the query
3. Completeness of information
From the provided results, select the SINGLE BEST match that most directly answers the query.
Format your response STRICTLY as follows, with each field on a new line:
{{Best Result: [video_id]}}
{{Timestamp: [timestamp]}}
{{Content: [subtitle text]}}
{{Reasoning: [Brief explanation of why this result best answers the query]}}
{rag_results}"""

    # 3. Call LLM
    print("Step 3: Querying the LLM...")
    try:
        output = llm.create_chat_completion(
            messages=[
                {"role": "system", "content": "You are a helpful assistant designed to select the best video segment based on relevance to a query, following a specific output format."},
                {"role": "user", "content": prompt},
            ],
            temperature=0.1,
            max_tokens=300,
        )
        # Coerce a missing message body to "" so parsing cannot crash on None.
        llm_response_text = output['choices'][0]['message']['content'] or ""
        print(f"LLM Response:\n{llm_response_text}")
    except Exception as e:
        print(f"Error during LLM call: {e}")
        return f"Error calling LLM: {e}", None

    # 4. Parse LLM Response
    print("Step 4: Parsing LLM response...")
    parsed_data = parse_llm_output(llm_response_text)
    video_id = parsed_data.get('video_id')
    timestamp_str = parsed_data.get('timestamp')
    reasoning = parsed_data.get('reasoning')

    if not video_id or not timestamp_str:
        print("Error: Could not parse required video_id or timestamp from LLM response.")
        fallback_reasoning = reasoning if reasoning else "Could not determine the best segment."
        error_msg = f"Failed to parse LLM response. LLM said:\n---\n{llm_response_text}\n---\nReasoning (if found): {fallback_reasoning}"
        return error_msg, None

    try:
        timestamp = float(timestamp_str)
    except ValueError:
        print(f"Error: Could not convert parsed timestamp '{timestamp_str}' to float.")
        error_msg = f"Invalid timestamp format from LLM ('{timestamp_str}'). LLM reasoning (if found): {reasoning}"
        return error_msg, None

    # Start a quarter of the clip length before the hit (never before 0) so
    # the viewer gets some lead-in context.
    start_time = max(0.0, timestamp - (VIDEO_SEGMENT_DURATION / 4))
    final_reasoning = reasoning if reasoning else "No reasoning provided by LLM."

    # 5. Extract Video Segment
    print(f"Step 5: Extracting video segment (ID: {video_id}, Start: {start_time:.2f}s, Duration: {VIDEO_SEGMENT_DURATION}s)...")
    video_path = extract_video_segment(video_id, start_time, VIDEO_SEGMENT_DURATION, dataset)

    if video_path and os.path.exists(video_path):
        print(f"Video segment extracted successfully: {video_path}")
        return final_reasoning, video_path

    print("Failed to extract video segment.")
    error_msg = f"{final_reasoning}\n\n(However, failed to extract the corresponding video segment for ID {video_id} at timestamp {timestamp_str}.)"
    return error_msg, None
# Gradio UI: one question box, the LLM's reasoning, and the selected clip.
# NOTE(review): the original indentation was lost, so the exact nesting of
# the widgets inside the gr.Row() blocks is reconstructed — confirm the
# button placement against the intended layout.
with gr.Blocks() as iface:
    gr.Markdown(
        """
        # Lecture Videos Q&A
        Ask a question about the lectures. The system will find relevant segments using RAG
        and a fine-tuned LLM to select the best one, and display the corresponding video clip.
        """
    )
    with gr.Row():
        query_input = gr.Textbox(label="Your Question", placeholder="Using only the videos, explain how ResNets work.")
    submit_button = gr.Button("Ask & Find Video")
    with gr.Row():
        reasoning_output = gr.Markdown(label="LLM Reasoning")
    with gr.Row():
        video_output = gr.Video(label="Relevant Video Segment")

    # The handler returns (message, video_path), matching the two outputs.
    submit_button.click(
        fn=process_query_and_get_video,
        inputs=query_input,
        outputs=[reasoning_output, video_output],
    )

    gr.Examples(
        examples=[
            "Using only the videos, explain how ResNets work.",
            "Using only the videos, explain the advantages of CNNs over fully connected networks.",
            "Using only the videos, explain the binary cross entropy loss function.",
        ],
        inputs=query_input,
        outputs=[reasoning_output, video_output],
        fn=process_query_and_get_video,
        # Examples are computed on click rather than pre-computed at startup.
        cache_examples=False,
    )

print("Launching Gradio interface...")
iface.launch(debug=True, share=False)