Spaces:
Sleeping
Sleeping
| import os | |
| import sys | |
| from unittest.mock import MagicMock, patch | |
# Install MagicMock stand-ins for the heavy native packages before `agent`
# is imported; presumably agent.py imports cv2/whisper, and these stubs let
# it load without OpenCV or Whisper installed.
sys.modules.update({
    'cv2': MagicMock(),
    'whisper': MagicMock(),
})

# Placeholder credentials for every provider. These unconditionally
# overwrite anything already present in the environment, keeping real keys
# out of the test run.
os.environ.update(dict.fromkeys(
    (
        "OPENROUTER_API_KEY",
        "GOOGLE_API_KEY",
        "GROQ_API_KEY",
        "NVIDIA_API_KEY",
        "VERCEL_API_KEY",
    ),
    "dummy",
))

# Let `import agent` also resolve agent.py from the launch directory.
sys.path.append(os.getcwd())
| import agent | |
| from langchain_core.messages import HumanMessage | |
def test_gemini_alternatives_on_rate_limit():
    """Check that smart_invoke falls back to an alternative Gemini model.

    OpenRouter and the primary Gemini model are both made to raise a
    rate-limit error. The next ChatGoogleGenerativeAI instance built inside
    ``agent.smart_invoke`` succeeds. The call must report tier 1 (Gemini),
    return the alternative model's response, and must actually have tried the
    rate-limited primary Gemini model first.
    """
    print("Testing Gemini alternatives on rate limit...")

    # OpenRouter fails with a rate limit.
    mock_openrouter = MagicMock()
    mock_openrouter.invoke.side_effect = Exception("Rate limit (429)")

    # First Gemini call (primary) fails with rate limit.
    mock_primary = MagicMock()
    mock_primary.invoke.side_effect = Exception("Rate limit (429)")
    mock_primary.model = "gemini-2.5-flash"

    # Second Gemini call (alternative) succeeds.
    mock_alt = MagicMock()
    mock_alt.invoke.return_value = MagicMock(content="Gemini alternative response")
    mock_alt.model = "gemini-2.5-flash-lite"

    # Replace the whole openrouter_model object instead of patching its
    # `invoke` attribute. openrouter_model is presumably a LangChain chat
    # model, i.e. a pydantic model. Pydantic v2 rejects setting non-field
    # attributes on an instance (ValueError: object has no field "invoke"),
    # which breaks patch('agent.openrouter_model.invoke'). This mirrors how
    # the already-created gemini_model is replaced.
    with patch('agent.openrouter_model', mock_openrouter), \
         patch('agent.gemini_model', mock_primary), \
         patch('agent.ChatGoogleGenerativeAI') as mock_gemini_class:
        # Control the sequence of ChatGoogleGenerativeAI creation inside
        # smart_invoke. Only the alternative is expected to be constructed;
        # a second construction would make the mock raise StopIteration.
        mock_gemini_class.side_effect = [mock_alt]

        msgs = [HumanMessage(content="Hello")]
        response, tier_idx = agent.smart_invoke(msgs, use_tools=False)
        print(f"Response from tier {tier_idx}: {response.content}")

    # Tier 1 is Gemini
    assert tier_idx == 1
    assert response.content == "Gemini alternative response"
    # The fallback is only exercised if the rate-limited primary was tried.
    assert mock_primary.invoke.called, "primary Gemini model was never attempted"
    print("Gemini alternative on rate limit successful!")
def test_nvidia_name():
    """Assert that agent's module-level NVIDIA client targets the expected model."""
    print("Checking NVIDIA model name...")
    expected = "meta/llama-3.1-405b-instruct"
    actual = agent.nvidia_model.model_name
    assert actual == expected
    print("NVIDIA model name is correct!")
if __name__ == "__main__":
    import traceback

    # Run each check in order and stop at the first failure. The failing
    # test's name and exception type are reported because a bare `assert`
    # carries no message, so str(e) alone is often empty.
    for _test in (test_gemini_alternatives_on_rate_limit, test_nvidia_name):
        try:
            _test()
        except Exception as e:
            print(f"Test failed: {_test.__name__}: {type(e).__name__}: {e}")
            traceback.print_exc()
            sys.exit(1)
    print("All fix tests passed!")