Spaces:
Runtime error
Runtime error
| """Tests for the Video Analyzer application.""" | |
| from __future__ import annotations | |
| import os | |
| import tempfile | |
| from unittest.mock import MagicMock, patch | |
| import pytest | |
class TestChunkText:
    """Unit tests for ``app.chunk_text``."""

    def test_empty_text(self):
        """An empty string produces an empty list of chunks."""
        from app import chunk_text

        result = chunk_text("")
        assert result == []

    def test_single_word(self):
        """A single word is returned as exactly one chunk."""
        from app import chunk_text

        result = chunk_text("hello")
        assert result == ["hello"]

    def test_text_smaller_than_chunk_size(self):
        """Text shorter than ``chunk_size`` comes back unchanged as one chunk."""
        from app import chunk_text

        text = "This is a short sentence."
        result = chunk_text(text, chunk_size=100, overlap=10)
        assert len(result) == 1
        assert result[0] == text

    def test_text_larger_than_chunk_size(self):
        """Text longer than ``chunk_size`` is split into several chunks."""
        from app import chunk_text

        text = " ".join(["word"] * 100)
        result = chunk_text(text, chunk_size=30, overlap=5)
        assert len(result) > 1

    def test_overlap_creates_overlapping_chunks(self):
        """Consecutive chunks share words when ``overlap`` is non-zero."""
        from app import chunk_text

        text = " ".join(f"word{i}" for i in range(20))
        result = chunk_text(text, chunk_size=10, overlap=3)
        # The input is far longer than chunk_size=10, so splitting must happen.
        # Asserting it (rather than guarding with ``if``) stops the overlap
        # check below from passing vacuously when only one chunk is returned.
        assert len(result) > 1
        shared = set(result[0].split()) & set(result[1].split())
        assert len(shared) > 0

    def test_default_parameters(self):
        """Test default chunk_size=500 and overlap=50."""
        from app import chunk_text

        text = " ".join(["word"] * 600)
        result = chunk_text(text)
        assert len(result) >= 2
class TestGetDevice:
    """Checks that ``app.get_device`` follows ``torch.cuda.is_available``."""

    def test_returns_cuda_when_available(self):
        """A patched CUDA-available check yields the 'cuda' device string."""
        with patch("torch.cuda.is_available", return_value=True):
            from app import get_device

            device = get_device()
            assert device == "cuda"

    def test_returns_cpu_when_cuda_unavailable(self):
        """A patched CUDA-unavailable check yields the 'cpu' device string."""
        with patch("torch.cuda.is_available", return_value=False):
            from app import get_device

            device = get_device()
            assert device == "cpu"
class TestHello:
    """Checks the text returned by ``app.hello``."""

    def test_no_profile_returns_login_message(self):
        """Without a profile the user is asked to log in."""
        from app import hello

        expected = "Please log in to continue."
        assert hello(None) == expected

    def test_with_profile_returns_greeting(self):
        """A profile is greeted using its ``name`` attribute."""
        from app import hello

        profile = MagicMock()
        # Assigned after construction: MagicMock(name=...) would set the
        # mock's repr name, not a ``.name`` attribute.
        profile.name = "TestUser"
        assert hello(profile) == "Hello TestUser!"
class TestTranscribeAudio:
    """Checks ``app.transcribe_audio`` for missing and present audio files."""

    def test_nonexistent_file_returns_empty_string(self):
        """A path that does not exist produces an empty transcript."""
        from app import transcribe_audio

        missing_path = "/nonexistent/path/audio.mp3"
        assert transcribe_audio(missing_path, MagicMock()) == ""

    def test_existing_file_calls_whisper(self):
        """An existing file is passed to the whisper callable exactly once."""
        from app import transcribe_audio

        # delete=False keeps the file after the handle closes; removed in finally.
        with tempfile.NamedTemporaryFile(suffix=".mp3", delete=False) as handle:
            audio_path = handle.name
            handle.write(b"fake audio data")
        try:
            whisper = MagicMock(return_value={"text": "Hello world"})
            transcript = transcribe_audio(audio_path, whisper)
            assert transcript == "Hello world"
            whisper.assert_called_once()
        finally:
            os.unlink(audio_path)
class TestGetKnowledgeStats:
    """Checks ``app.get_knowledge_stats`` with ``app.collection`` patched."""

    def test_empty_knowledge_base(self):
        """A zero count is reported as empty (or at least mentions 0)."""
        # Called without a session argument, so the module-level
        # ``app.collection`` is patched to control the reported count.
        with patch("app.collection") as fake_collection:
            fake_collection.count.return_value = 0
            from app import get_knowledge_stats

            summary = get_knowledge_stats()
            assert "empty" in summary.lower() or "0" in summary

    def test_populated_knowledge_base(self):
        """A non-zero count shows up in the summary text."""
        with patch("app.collection") as fake_collection:
            fake_collection.count.return_value = 42
            from app import get_knowledge_stats

            summary = get_knowledge_stats()
            assert "42" in summary or "chunks" in summary.lower()
class TestExtractAudio:
    """Checks ``app.extract_audio`` with ``subprocess.run`` mocked out."""

    def test_successful_extraction(self):
        """A zero-exit ffmpeg run returns the path of ``audio.mp3`` in the output dir."""
        with patch("app.subprocess.run") as run_mock:
            run_mock.return_value = MagicMock(returncode=0)
            from app import extract_audio

            with tempfile.TemporaryDirectory() as workdir:
                video_file = os.path.join(workdir, "test.mp4")
                expected_audio = os.path.join(workdir, "audio.mp3")
                # ffmpeg never really runs, so pre-create the file it would write.
                for path, content in (
                    (video_file, "fake video"),
                    (expected_audio, "fake audio"),
                ):
                    with open(path, "w") as fh:
                        fh.write(content)
                assert extract_audio(video_file, workdir) == expected_audio

    def test_ffmpeg_not_found(self):
        """A missing ffmpeg binary surfaces as RuntimeError('FFmpeg not found...')."""
        with patch("app.subprocess.run") as run_mock:
            run_mock.side_effect = FileNotFoundError()
            from app import extract_audio

            with tempfile.TemporaryDirectory() as workdir:
                video_file = os.path.join(workdir, "test.mp4")
                with open(video_file, "w") as fh:
                    fh.write("fake video")
                with pytest.raises(RuntimeError, match="FFmpeg not found"):
                    extract_audio(video_file, workdir)
class TestDownloadVideo:
    """Checks ``app.download_video`` with ``yt_dlp.YoutubeDL`` mocked out."""

    @staticmethod
    def _stub_extract_info(ydl_cls, info):
        """Make ``with YoutubeDL(...) as ydl: ydl.extract_info(...)`` return *info*."""
        ydl = MagicMock()
        ydl.extract_info.return_value = info
        ydl_cls.return_value.__enter__.return_value = ydl
        return ydl

    def test_invalid_url_raises_error(self):
        """No info from yt-dlp is reported as RuntimeError('Failed to download...')."""
        with patch("app.yt_dlp.YoutubeDL") as ydl_cls:
            self._stub_extract_info(ydl_cls, None)
            from app import download_video

            with tempfile.TemporaryDirectory() as workdir:
                with pytest.raises(RuntimeError, match="Failed to download"):
                    download_video("invalid_url", workdir)

    def test_single_video_download(self):
        """A single-video info dict becomes a one-element result list."""
        info = {"title": "Test Video", "ext": "mp4", "duration": 120}
        with patch("app.yt_dlp.YoutubeDL") as ydl_cls:
            self._stub_extract_info(ydl_cls, info)
            from app import download_video

            with tempfile.TemporaryDirectory() as workdir:
                videos = download_video("https://youtube.com/watch?v=test", workdir)
                assert len(videos) == 1
                only = videos[0]
                assert only["title"] == "Test Video"
                assert only["duration"] == 120

    def test_playlist_download(self):
        """A playlist info dict yields one result per entry, in order."""
        info = {
            "entries": [
                {"title": "Video 1", "ext": "mp4", "duration": 60},
                {"title": "Video 2", "ext": "mp4", "duration": 90},
            ]
        }
        with patch("app.yt_dlp.YoutubeDL") as ydl_cls:
            self._stub_extract_info(ydl_cls, info)
            from app import download_video

            with tempfile.TemporaryDirectory() as workdir:
                videos = download_video("https://youtube.com/playlist?list=test", workdir)
                assert len(videos) == 2
                first, second = videos
                assert first["title"] == "Video 1"
                assert second["title"] == "Video 2"
class TestChatWithVideos:
    """Checks the guard-clause responses of ``app.chat_with_videos``."""

    def test_no_profile_returns_login_message(self):
        """A missing profile asks the user to log in."""
        from app import chat_with_videos

        reply = chat_with_videos("test", [], None, MagicMock())
        assert "log in" in reply.lower()

    def test_no_token_returns_auth_message(self):
        """A missing OAuth token produces an authentication message."""
        from app import chat_with_videos

        reply = chat_with_videos("test", [], MagicMock(), None)
        assert "authentication" in reply.lower()

    def test_empty_message_returns_prompt(self):
        """An empty question prompts the user to enter one."""
        from app import chat_with_videos

        reply = chat_with_videos("", [], MagicMock(), MagicMock())
        assert "enter a question" in reply.lower()

    def test_empty_knowledge_base_returns_message(self):
        """With nothing indexed, the reply says no videos have been analyzed."""
        with patch("app.collection") as fake_collection:
            fake_collection.count.return_value = 0
            from app import chat_with_videos

            reply = chat_with_videos("test query", [], MagicMock(), MagicMock())
            assert "no videos have been analyzed" in reply.lower()
class TestUrlValidation:
    """Checks ``app.is_valid_youtube_url``, which returns ``(is_valid, url_or_message)``."""

    def test_valid_watch_url(self):
        """A standard ``/watch?v=`` URL is accepted."""
        from app import is_valid_youtube_url

        is_valid, _ = is_valid_youtube_url("https://youtube.com/watch?v=abc123")
        assert is_valid is True

    def test_valid_short_url(self):
        """A ``youtu.be`` short link is accepted."""
        from app import is_valid_youtube_url

        is_valid, _ = is_valid_youtube_url("https://youtu.be/abc123")
        assert is_valid is True

    def test_valid_playlist_url(self):
        """A playlist URL is accepted."""
        from app import is_valid_youtube_url

        is_valid, _ = is_valid_youtube_url("https://youtube.com/playlist?list=abc")
        assert is_valid is True

    def test_empty_url(self):
        """An empty string is rejected with a prompt to enter a URL."""
        from app import is_valid_youtube_url

        is_valid, message = is_valid_youtube_url("")
        assert is_valid is False
        assert "enter" in message.lower()

    def test_invalid_url(self):
        """A non-YouTube host is rejected."""
        from app import is_valid_youtube_url

        is_valid, _ = is_valid_youtube_url("https://example.com/video")
        assert is_valid is False

    def test_url_without_protocol(self):
        """A scheme-less URL is accepted and normalized to ``https://``."""
        from app import is_valid_youtube_url

        is_valid, normalized = is_valid_youtube_url("youtube.com/watch?v=abc123")
        assert is_valid is True
        assert normalized.startswith("https://")
class TestProcessYoutube:
    """Checks the input-validation responses of ``app.process_youtube``."""

    def test_no_profile_returns_login_message(self):
        """Without a profile the user is asked to log in."""
        from app import process_youtube

        progress = MagicMock()
        reply = process_youtube("https://youtube.com/watch?v=test", 5, None, None, progress)
        assert "log in" in reply.lower()

    def test_empty_url_returns_prompt(self):
        """An empty URL prompts the user to enter one."""
        from app import process_youtube

        progress = MagicMock()
        reply = process_youtube("", 5, MagicMock(), None, progress)
        assert "enter" in reply.lower()

    def test_invalid_url_returns_error(self):
        """A malformed URL is reported as not a valid YouTube URL."""
        from app import process_youtube

        progress = MagicMock()
        reply = process_youtube("not-a-url", 5, MagicMock(), None, progress)
        assert "valid youtube url" in reply.lower()
class TestSessionState:
    """Checks ``app.SessionState`` and the ``create_session_state`` factory."""

    def test_creates_collection_with_session_id(self):
        """An explicit session ID is stored and a collection is created."""
        from app import SessionState

        session = SessionState("test_session_123")
        assert session.session_id == "test_session_123"
        assert session.collection is not None

    def test_auto_generates_session_id(self):
        """Omitting the ID generates a 32-character (UUID hex) session ID."""
        from app import SessionState

        session = SessionState()
        assert session.session_id is not None
        assert len(session.session_id) == 32  # UUID hex length

    def test_clear_recreates_collection(self):
        """``clear()`` leaves the session with an empty collection."""
        from app import SessionState

        session = SessionState("test_clear")
        session.collection.add(documents=["test doc"], ids=["test_id"])
        assert session.collection.count() == 1

        session.clear()
        assert session.collection.count() == 0

    def test_create_session_state_with_profile(self):
        """The same profile maps to the same session ID on repeated calls."""
        from app import create_session_state

        profile = MagicMock()
        profile.name = "TestUser"
        first, second = create_session_state(profile), create_session_state(profile)
        assert first.session_id == second.session_id

    def test_create_session_state_without_profile(self):
        """Anonymous sessions each receive a distinct ID."""
        from app import create_session_state

        first, second = create_session_state(None), create_session_state(None)
        assert first.session_id != second.session_id
class TestRAGPipeline:
    """Integration tests for the RAG (Retrieval Augmented Generation) pipeline.

    Each test builds its own ``SessionState`` and indexes content through
    ``add_to_vector_db`` before querying it with ``search_knowledge``.
    """

    def test_add_and_search_knowledge(self):
        """Indexed transcript text is returned by a matching search."""
        from app import SessionState, add_to_vector_db, search_knowledge

        state = SessionState("test_rag_1")
        add_to_vector_db(
            title="Test Video",
            transcript="This is a test about machine learning and neural networks.",
            visual_contexts=["A person standing at a whiteboard"],
            session_state=state,
        )

        results = search_knowledge("machine learning", session_state=state)
        assert len(results) > 0
        assert any("machine learning" in r["content"].lower() for r in results)
        assert results[0]["title"] == "Test Video"

    def test_search_returns_empty_for_unrelated_query(self):
        """An unrelated query still completes and returns a list.

        The name is historical: a vector search returns nearest neighbours
        even when nothing is relevant, so this test only checks that the
        query does not crash and that the result is a list. It does not check
        that the list is empty.
        """
        from app import SessionState, add_to_vector_db, search_knowledge

        state = SessionState("test_rag_2")
        add_to_vector_db(
            title="Cooking Show",
            transcript="Today we will make a delicious pasta with tomato sauce.",
            visual_contexts=["Chef in kitchen"],
            session_state=state,
        )

        results = search_knowledge("quantum physics equations", session_state=state)
        assert isinstance(results, list)

    def test_visual_contexts_are_searchable(self):
        """Visual-context descriptions are indexed and tagged as ``"visual"``."""
        from app import SessionState, add_to_vector_db, search_knowledge

        state = SessionState("test_rag_3")
        # Empty transcript: only the visual context can match the query below.
        add_to_vector_db(
            title="Nature Documentary",
            transcript="",
            visual_contexts=["A majestic elephant walking through the savanna"],
            session_state=state,
        )

        results = search_knowledge("elephant savanna", session_state=state)
        assert len(results) > 0
        assert any("elephant" in r["content"].lower() for r in results)
        assert results[0]["type"] == "visual"

    def test_multiple_videos_searchable(self):
        """Content from several videos in one session can each be retrieved."""
        from app import SessionState, add_to_vector_db, search_knowledge

        state = SessionState("test_rag_4")
        add_to_vector_db(
            title="Python Tutorial",
            transcript="Learn Python programming with functions and classes.",
            visual_contexts=[],
            session_state=state,
        )
        add_to_vector_db(
            title="JavaScript Guide",
            transcript="Master JavaScript with callbacks and promises.",
            visual_contexts=[],
            session_state=state,
        )

        python_results = search_knowledge("Python functions", session_state=state)
        assert any("Python" in r["title"] for r in python_results)

        js_results = search_knowledge("JavaScript promises", session_state=state)
        assert any("JavaScript" in r["title"] for r in js_results)

    def test_session_isolation(self):
        """Content added to one session is invisible to another session."""
        from app import SessionState, add_to_vector_db, search_knowledge

        state1 = SessionState("isolation_test_1")
        state2 = SessionState("isolation_test_2")
        add_to_vector_db(
            title="Session 1 Only",
            transcript="Unique content about dragons and wizards.",
            visual_contexts=[],
            session_state=state1,
        )

        results1 = search_knowledge("dragons wizards", session_state=state1)
        assert len(results1) > 0

        results2 = search_knowledge("dragons wizards", session_state=state2)
        assert len(results2) == 0
class TestChatWithVideosIntegration:
    """Runs ``chat_with_videos`` over a real session index with a mocked LLM client."""

    @staticmethod
    def _completion(text):
        """Build a fake chat-completion response whose first choice carries *text*."""
        response = MagicMock()
        response.choices = [MagicMock()]
        response.choices[0].message.content = text
        return response

    @staticmethod
    def _credentials():
        """Return a ``(profile, token)`` pair of mocks with ``token.token`` set."""
        profile = MagicMock()
        token = MagicMock()
        token.token = "test_token"
        return profile, token

    def test_chat_retrieves_relevant_context(self):
        """The answer is returned along with a Sources section naming the video."""
        from app import SessionState, add_to_vector_db, chat_with_videos

        state = SessionState("chat_test_1")
        add_to_vector_db(
            title="AI Lecture",
            transcript="Artificial intelligence is transforming healthcare. Machine learning models can diagnose diseases.",
            visual_contexts=["Professor presenting slides about AI"],
            session_state=state,
        )
        profile, token = self._credentials()

        with patch("app.InferenceClient") as client_cls:
            create = client_cls.return_value.chat.completions.create
            create.return_value = self._completion(
                "AI is transforming healthcare by enabling better diagnosis."
            )
            answer = chat_with_videos(
                message="What is AI used for in healthcare?",
                history=[],
                profile=profile,
                oauth_token=token,
                session_state=state,
            )

        assert "AI" in answer or "healthcare" in answer
        assert "Sources:" in answer
        assert "AI Lecture" in answer

    def test_chat_includes_model_info(self):
        """The reply reports which model produced it."""
        from app import SessionState, add_to_vector_db, chat_with_videos

        state = SessionState("chat_test_2")
        add_to_vector_db(
            title="Test Video",
            transcript="Some test content here.",
            visual_contexts=[],
            session_state=state,
        )
        profile, token = self._credentials()

        with patch("app.InferenceClient") as client_cls:
            create = client_cls.return_value.chat.completions.create
            create.return_value = self._completion("Test response.")
            answer = chat_with_videos(
                message="Tell me about the test content",
                history=[],
                profile=profile,
                oauth_token=token,
                session_state=state,
            )

        assert "Model:" in answer

    def test_chat_handles_api_error(self):
        """An exception from every model call becomes a user-facing error message."""
        from app import SessionState, add_to_vector_db, chat_with_videos

        state = SessionState("chat_test_3")
        add_to_vector_db(
            title="Test Video",
            transcript="Some content.",
            visual_contexts=[],
            session_state=state,
        )
        profile, token = self._credentials()

        with patch("app.InferenceClient") as client_cls:
            create = client_cls.return_value.chat.completions.create
            create.side_effect = Exception("503 Service Unavailable")
            answer = chat_with_videos(
                message="Test question",
                history=[],
                profile=profile,
                oauth_token=token,
                session_state=state,
            )

        lowered = answer.lower()
        assert "unavailable" in lowered or "error" in lowered
class TestHandleChat:
    """Integration tests for the unified ``handle_chat`` entry point.

    ``handle_chat`` returns ``(history, message, session_state)``; the history
    entries are dicts with ``"role"`` and ``"content"`` keys.
    """

    @staticmethod
    def _completion(text):
        """Build a fake chat-completion response whose first choice carries *text*."""
        response = MagicMock()
        response.choices = [MagicMock()]
        response.choices[0].message.content = text
        return response

    def test_detects_youtube_url(self):
        """A YouTube URL message is routed to URL processing."""
        from app import SessionState, handle_chat

        state = SessionState("handle_test_1")
        profile, token = MagicMock(), MagicMock()

        # _process_youtube_impl is mocked, so no download happens; the test
        # only checks that the URL is dispatched to it.
        with patch("app._process_youtube_impl") as process_impl:
            process_impl.return_value = "## Test Video\n\nTranscript here"
            history, msg, new_state = handle_chat(
                message="https://youtube.com/watch?v=test123",
                history=[],
                session_state=state,
                profile=profile,
                oauth_token=token,
            )
            process_impl.assert_called_once()
            # User message + assistant response.
            assert len(history) >= 2

    def test_detects_question(self):
        """A plain question against an empty index prompts the user to add videos."""
        from app import SessionState, handle_chat

        state = SessionState("handle_test_2")
        profile, token = MagicMock(), MagicMock()
        token.token = "test_token"

        history, msg, new_state = handle_chat(
            message="What is this video about?",
            history=[],
            session_state=state,
            profile=profile,
            oauth_token=token,
        )

        assert len(history) >= 2
        reply = history[-1]["content"].lower()
        assert "don't have any videos" in reply or "paste a youtube url" in reply

    def test_answers_question_with_knowledge(self):
        """With indexed content, a question gets the LLM's answer appended."""
        from app import SessionState, add_to_vector_db, handle_chat

        state = SessionState("handle_test_3")
        add_to_vector_db(
            title="Cooking Video",
            transcript="Today we make pasta. Boil water, add salt, cook for 10 minutes.",
            visual_contexts=["Chef stirring pot"],
            session_state=state,
        )
        profile, token = MagicMock(), MagicMock()
        token.token = "test_token"

        with patch("app.InferenceClient") as client_cls:
            client_cls.return_value.chat.completions.create.return_value = self._completion(
                "To cook pasta, boil water and add salt."
            )
            history, msg, new_state = handle_chat(
                message="How do I cook pasta?",
                history=[],
                session_state=state,
                profile=profile,
                oauth_token=token,
            )
            assert len(history) >= 2
            reply = history[-1]["content"].lower()
            assert "pasta" in reply or "cook" in reply

    def test_requires_login(self):
        """Without a profile the reply asks the user to sign in."""
        from app import SessionState, handle_chat

        state = SessionState("handle_test_4")
        history, msg, new_state = handle_chat(
            message="Hello",
            history=[],
            session_state=state,
            profile=None,
            oauth_token=None,
        )

        assert len(history) >= 2
        assert "sign in" in history[-1]["content"].lower()

    def test_creates_session_if_none(self):
        """A missing session state is replaced by a freshly created one."""
        from app import handle_chat

        profile = MagicMock()
        profile.name = "TestUser"
        history, msg, new_state = handle_chat(
            message="Hello",
            history=[],
            session_state=None,
            profile=profile,
            oauth_token=MagicMock(),
        )

        assert new_state is not None
        assert new_state.session_id is not None
class TestGetKnowledgeStatsWithSession:
    """Checks ``app.get_knowledge_stats`` when given an explicit ``SessionState``."""

    def test_empty_session_knowledge_base(self):
        """A brand-new session is reported as empty."""
        from app import SessionState, get_knowledge_stats

        summary = get_knowledge_stats(SessionState("stats_test_1"))
        assert "empty" in summary.lower()

    def test_populated_session_knowledge_base(self):
        """After indexing two videos, the summary mentions chunks and video titles."""
        from app import SessionState, add_to_vector_db, get_knowledge_stats

        state = SessionState("stats_test_2")
        videos = [
            ("Test Video 1", "Some content here about testing.", ["Test scene"]),
            ("Test Video 2", "More content about different things.", []),
        ]
        for title, transcript, visuals in videos:
            add_to_vector_db(
                title=title,
                transcript=transcript,
                visual_contexts=visuals,
                session_state=state,
            )

        summary = get_knowledge_stats(state)
        assert "chunks" in summary.lower() or "2" in summary
        assert "Test Video" in summary
| class TestConversationalFlow: | |
| """Tests for multi-turn conversational interactions with the chatbot.""" | |
| def test_multi_turn_conversation(self): | |
| """Test that chatbot can handle follow-up questions using history.""" | |
| from app import SessionState, add_to_vector_db, handle_chat | |
| state = SessionState("convo_test_1") | |
| # Add content about a cooking video | |
| add_to_vector_db( | |
| title="Italian Cooking", | |
| transcript="Today we make authentic Italian pasta. First boil water. Add salt. " | |
| "Cook pasta for 8 minutes. The sauce uses fresh tomatoes, garlic, and basil.", | |
| visual_contexts=["Chef chopping tomatoes", "Boiling pot of pasta"], | |
| session_state=state, | |
| ) | |
| mock_profile = MagicMock() | |
| mock_token = MagicMock() | |
| mock_token.token = "test_token" | |
| with patch("app.InferenceClient") as mock_client: | |
| # First question | |
| mock_response1 = MagicMock() | |
| mock_response1.choices = [MagicMock()] | |
| mock_response1.choices[0].message.content = "The video shows how to make Italian pasta with a tomato sauce." | |
| # Follow-up question | |
| mock_response2 = MagicMock() | |
| mock_response2.choices = [MagicMock()] | |
| mock_response2.choices[0].message.content = "The sauce ingredients are fresh tomatoes, garlic, and basil." | |
| mock_client.return_value.chat.completions.create.side_effect = [mock_response1, mock_response2] | |
| # First turn | |
| history1, _, state = handle_chat( | |
| message="What is this video about?", | |
| history=[], | |
| session_state=state, | |
| profile=mock_profile, | |
| oauth_token=mock_token, | |
| ) | |
| first_turn_len = len(history1) | |
| assert first_turn_len >= 2 | |
| assert "pasta" in history1[-1]["content"].lower() or "Italian" in history1[-1]["content"] | |
| # Second turn - follow-up question using history | |
| history2, _, state = handle_chat( | |
| message="What ingredients are in the sauce?", | |
| history=history1, # Pass previous history | |
| session_state=state, | |
| profile=mock_profile, | |
| oauth_token=mock_token, | |
| ) | |
| # Should have more messages now (history is mutated in place) | |
| assert len(history2) == 4 # 2 turns x 2 messages each | |
| # Last response should be about ingredients | |
| assert "tomatoes" in history2[-1]["content"].lower() or "sauce" in history2[-1]["content"].lower() | |
| def test_history_preserves_context(self): | |
| """Test that conversation history preserves context for follow-ups.""" | |
| from app import SessionState, add_to_vector_db, handle_chat | |
| state = SessionState("convo_test_2") | |
| # Add content | |
| add_to_vector_db( | |
| title="Python Tutorial", | |
| transcript="Python is a programming language. Variables store data. " | |
| "Functions are defined with def keyword. Classes use the class keyword.", | |
| visual_contexts=["Code editor showing Python"], | |
| session_state=state, | |
| ) | |
| mock_profile = MagicMock() | |
| mock_token = MagicMock() | |
| mock_token.token = "test_token" | |
| with patch("app.InferenceClient") as mock_client: | |
| mock_response = MagicMock() | |
| mock_response.choices = [MagicMock()] | |
| mock_response.choices[0].message.content = "Functions are defined using the def keyword." | |
| mock_client.return_value.chat.completions.create.return_value = mock_response | |
| # Build up a conversation | |
| history = [] | |
| # Turn 1: Ask about functions | |
| history, _, state = handle_chat( | |
| message="How do you define functions in Python?", | |
| history=history, | |
| session_state=state, | |
| profile=mock_profile, | |
| oauth_token=mock_token, | |
| ) | |
| # Verify history structure | |
| assert len(history) == 2 # User + Assistant | |
| assert history[0]["role"] == "user" | |
| assert history[1]["role"] == "assistant" | |
| assert "function" in history[0]["content"].lower() | |
| def test_user_messages_added_to_history(self): | |
| """Test that user messages are properly added to history.""" | |
| from app import SessionState, handle_chat | |
| state = SessionState("convo_test_3") | |
| mock_profile = MagicMock() | |
| history, _, state = handle_chat( | |
| message="Hello chatbot!", | |
| history=[], | |
| session_state=state, | |
| profile=mock_profile, | |
| oauth_token=MagicMock(), | |
| ) | |
| # User message should be in history | |
| user_messages = [h for h in history if h["role"] == "user"] | |
| assert len(user_messages) >= 1 | |
| assert user_messages[0]["content"] == "Hello chatbot!" | |
def test_assistant_responses_added_to_history(self):
    """Test that assistant responses are properly added to history.

    Besides checking that an assistant entry exists, this pins that the
    mocked LLM reply text actually reaches the history. Without that check,
    the test would also pass if the reply were dropped or replaced by an
    error message.
    """
    from app import SessionState, add_to_vector_db, handle_chat

    state = SessionState("convo_test_4")
    add_to_vector_db(
        title="Test",
        transcript="Test content.",
        visual_contexts=[],
        session_state=state,
    )
    mock_profile = MagicMock()
    mock_token = MagicMock()
    mock_token.token = "test"
    expected_reply = "This is my response."
    with patch("app.InferenceClient") as mock_client:
        mock_response = MagicMock()
        mock_response.choices = [MagicMock()]
        mock_response.choices[0].message.content = expected_reply
        mock_client.return_value.chat.completions.create.return_value = mock_response
        history, _, _ = handle_chat(
            message="Tell me about the test",
            history=[],
            session_state=state,
            profile=mock_profile,
            oauth_token=mock_token,
        )
    # Assistant message should be in history
    assistant_messages = [h for h in history if h["role"] == "assistant"]
    assert len(assistant_messages) >= 1
    # Containment rather than equality: the app may decorate the reply
    # (e.g. append sources) — NOTE(review): tighten if it never does.
    assert expected_reply in assistant_messages[-1]["content"]
def test_can_ask_about_specific_parts(self):
    """Test asking specific questions about video content.

    Each topic-specific search must return results that mention that
    topic, not merely a non-empty result list.
    """
    from app import SessionState, add_to_vector_db, search_knowledge

    state = SessionState("specific_test")
    # Add detailed content
    add_to_vector_db(
        title="Science Documentary",
        transcript="The documentary covers three topics. First, black holes are massive objects. "
        "Second, neutron stars are extremely dense. Third, galaxies contain billions of stars.",
        visual_contexts=[
            "Animation of black hole",
            "Diagram of neutron star",
            "Hubble image of galaxy",
        ],
        session_state=state,
    )
    # Search for specific topic
    results = search_knowledge("black holes", session_state=state)
    assert len(results) > 0
    assert any("black hole" in r["content"].lower() for r in results)
    # Search for another topic
    results = search_knowledge("neutron stars", session_state=state)
    assert len(results) > 0
    assert any("neutron" in r["content"].lower() for r in results)
    # Search for visual content; "galax" matches both the transcript's
    # "galaxies" and the visual context's "galaxy".
    results = search_knowledge("galaxy image", session_state=state)
    assert len(results) > 0
    assert any("galax" in r["content"].lower() for r in results)
| class TestLLMContextPassing: | |
| """Tests to verify correct context is passed to the LLM.""" | |
def test_context_includes_relevant_video_content(self):
    """Check that indexed video content is forwarded to the LLM.

    After one machine-learning video is indexed, a question about neural
    networks must yield a chat request with exactly two messages (system +
    user). The user message must quote the indexed transcript.
    """
    from app import SessionState, add_to_vector_db, chat_with_videos

    session = SessionState("context_test_1")
    add_to_vector_db(
        title="Machine Learning Basics",
        transcript="Neural networks consist of layers. Input layer, hidden layers, and output layer.",
        visual_contexts=["Diagram of neural network architecture"],
        session_state=session,
    )
    token = MagicMock()
    token.token = "test"

    # ``messages`` payload of every completions.create call, in call order.
    seen_requests = []

    def fake_create(*_args, **call_kwargs):
        seen_requests.append(call_kwargs.get("messages", []))
        reply = MagicMock()
        reply.choices = [MagicMock()]
        reply.choices[0].message.content = "Response"
        return reply

    with patch("app.InferenceClient") as client_cls:
        client_cls.return_value.chat.completions.create.side_effect = fake_create
        chat_with_videos(
            message="Tell me about neural networks",
            history=[],
            profile=MagicMock(),
            oauth_token=token,
            session_state=session,
        )

    # The LLM must have been called; inspect the most recent request.
    assert seen_requests
    sent = seen_requests[-1]
    assert len(sent) == 2  # system + user
    # The user turn should carry the retrieved video content.
    prompt_text = sent[1]["content"]
    assert "neural" in prompt_text.lower()
    assert "layers" in prompt_text.lower()
def test_system_prompt_instructs_rag_behavior(self):
    """Test that system prompt instructs LLM to use provided context.

    The first message sent to the (mocked) LLM is treated as the system
    prompt. It must mention both "video" and "context".
    """
    from app import SessionState, add_to_vector_db, chat_with_videos

    state = SessionState("context_test_2")
    add_to_vector_db(
        title="Test",
        transcript="Content here.",
        visual_contexts=[],
        session_state=state,
    )
    mock_profile = MagicMock()
    mock_token = MagicMock()
    mock_token.token = "test"
    captured_messages = None
    with patch("app.InferenceClient") as mock_client:

        def capture_call(*args, **kwargs):
            # Record the chat payload and return a minimal fake completion.
            nonlocal captured_messages
            captured_messages = kwargs.get("messages", [])
            mock_resp = MagicMock()
            mock_resp.choices = [MagicMock()]
            mock_resp.choices[0].message.content = "Response"
            return mock_resp

        mock_client.return_value.chat.completions.create.side_effect = capture_call
        chat_with_videos(
            message="Question",
            history=[],
            profile=mock_profile,
            oauth_token=mock_token,
            session_state=state,
        )
    # Fail with a clear assertion (not a TypeError on None) if the LLM
    # was never called.
    assert captured_messages is not None
    assert len(captured_messages) >= 1
    # System prompt should instruct RAG behavior
    system_msg = captured_messages[0]["content"].lower()
    assert "video" in system_msg
    assert "context" in system_msg
def test_user_question_included_in_prompt(self):
    """Test that the user's actual question is included in the prompt.

    The second message sent to the (mocked) LLM is treated as the user
    turn. It must contain the question verbatim.
    """
    from app import SessionState, add_to_vector_db, chat_with_videos

    state = SessionState("context_test_3")
    add_to_vector_db(
        title="Test",
        transcript="Content.",
        visual_contexts=[],
        session_state=state,
    )
    mock_profile = MagicMock()
    mock_token = MagicMock()
    mock_token.token = "test"
    specific_question = "What are the three main ingredients mentioned?"
    captured_messages = None
    with patch("app.InferenceClient") as mock_client:

        def capture_call(*args, **kwargs):
            # Record the chat payload and return a minimal fake completion.
            nonlocal captured_messages
            captured_messages = kwargs.get("messages", [])
            mock_resp = MagicMock()
            mock_resp.choices = [MagicMock()]
            mock_resp.choices[0].message.content = "Response"
            return mock_resp

        mock_client.return_value.chat.completions.create.side_effect = capture_call
        chat_with_videos(
            message=specific_question,
            history=[],
            profile=mock_profile,
            oauth_token=mock_token,
            session_state=state,
        )
    # Fail with a clear assertion (not a TypeError/IndexError) if the LLM
    # was never called or received fewer than two messages.
    assert captured_messages is not None
    assert len(captured_messages) >= 2
    # User's question should be in the prompt
    user_msg = captured_messages[1]["content"]
    assert specific_question in user_msg