File size: 4,788 Bytes
ff6af76
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
#!/usr/bin/env python3
"""
Pre-deployment smoke test: verify the LLM provider configuration (Grok or OpenAI), environment loading, and API connectivity
"""
import os
import sys
from pathlib import Path

# Put this script's own directory at the front of the import path so the
# sibling modules imported later by the checks (inference,
# content_moderation_env) resolve from here; they are assumed to live next to
# this file.
sys.path[:0] = [str(Path(__file__).parent)]

def test_configuration():
    """Check that an API key for the selected LLM provider is set and report settings.

    Reads ``LLM_PROVIDER`` (default ``"openai"``). For ``"grok"`` the key is
    taken from ``XAI_API_KEY`` or ``GROK_API_KEY``. Any other value uses
    ``OPENAI_API_KEY`` or ``HF_TOKEN``. ``MODEL_NAME`` and ``API_BASE_URL``
    are reported, falling back to provider-specific defaults when unset.

    Returns:
        True if an API key was found, False otherwise.
    """
    print("🧪 Testing Configuration")
    print("=" * 50)

    provider = os.getenv("LLM_PROVIDER", "openai")
    print(f"✓ Provider: {provider}")

    # NOTE(review): these defaults presumably mirror the ones in inference.py,
    # which test_api_connection actually uses; keep the two in sync.
    if provider == "grok":
        api_key = os.getenv("XAI_API_KEY") or os.getenv("GROK_API_KEY")
        key_var = "XAI_API_KEY or GROK_API_KEY"
        default_model = "grok-beta"
        default_url = "https://api.x.ai/v1"
    else:
        api_key = os.getenv("OPENAI_API_KEY") or os.getenv("HF_TOKEN")
        key_var = "OPENAI_API_KEY or HF_TOKEN"
        default_model = "gpt-4o-mini"
        default_url = "https://api.openai.com/v1"

    if not api_key:
        print(f"✗ API Key ({key_var}): NOT SET")
        return False

    # Reveal a prefix/suffix (8 + 4 = 12 chars) only when at least as many
    # characters stay hidden. Otherwise a short token would be printed almost
    # in full to the console or deploy logs.
    masked = f"{api_key[:8]}...{api_key[-4:]}" if len(api_key) >= 24 else "***"
    print(f"✓ API Key ({key_var}): {masked}")

    model = os.getenv("MODEL_NAME", default_model)
    print(f"✓ Model: {model}")

    api_url = os.getenv("API_BASE_URL", default_url)
    print(f"✓ API URL: {api_url}")

    print("=" * 50)
    return True

def test_api_connection():
    """Send one tiny chat completion to check the configured API end to end.

    The base URL, model, key and provider name are imported from the sibling
    ``inference`` module rather than read from the environment here. This
    check therefore exercises the settings that module resolves.

    Returns:
        True if the request completed, False when the key is missing, a
        required module cannot be imported, or the client raises any error.
    """
    print("\n🔌 Testing API Connection")
    print("=" * 50)

    try:
        from openai import OpenAI
        from inference import API_BASE_URL, MODEL_NAME, API_KEY, PROVIDER

        if not API_KEY:
            print("✗ No API key configured")
            return False

        # Bound the wait. The client's default request timeout is long
        # (minutes), so an unreachable endpoint would stall this check.
        client = OpenAI(base_url=API_BASE_URL, api_key=API_KEY, timeout=30.0)

        print(f"Testing {PROVIDER} API...")
        response = client.chat.completions.create(
            model=MODEL_NAME,
            messages=[
                {"role": "system", "content": "You are a helpful assistant."},
                {"role": "user", "content": "Say 'Hello' if you can read this."}
            ],
            max_tokens=10,
            temperature=0
        )

        result = response.choices[0].message.content
        print(f"✓ API Response: {result}")
        print("✓ Connection successful!")
        return True

    except ImportError as e:
        # A missing package/module is a setup problem, not a network problem.
        print(f"✗ Import failed: {e}")
        return False
    except Exception as e:  # diagnostic boundary: report any client failure
        print(f"✗ Connection failed: {e}")
        return False

def test_environment_loading():
    """Smoke-test ContentModerationEnv: construct it, reset it, take one step.

    Loads ``moderation_benchmark.json`` with ``seed=42``, calls ``reset()``,
    then calls ``step()`` with the fixed action
    ``{"label": "safe", "action": "allow"}`` and prints the returned reward.

    Returns:
        True if every stage succeeded. False otherwise, including when
        ``content_moderation_env`` cannot be imported.
    """
    print("\n📁 Testing Environment Loading")
    print("=" * 50)

    try:
        # Import inside the try so a missing/broken module is reported as a
        # failed check instead of propagating out of main() before the
        # summary is printed.
        from content_moderation_env import ContentModerationEnv

        env = ContentModerationEnv("moderation_benchmark.json", seed=42)
        print(f"✓ Environment loaded: {env.num_scenarios} scenarios")

        env.reset()
        print("✓ Reset successful: scenario loaded")

        result = env.step({"label": "safe", "action": "allow"})
        print(f"✓ Step successful: reward = {result['reward']:.2f}")

        return True
    except Exception as e:  # diagnostic boundary: report any failure
        print(f"✗ Environment test failed: {e}")
        return False

def main():
    """Run the pre-deployment checks and print a PASS/FAIL summary.

    The checks run in order: configuration, environment loading, then API
    connection. The API check runs only when the configuration check passed,
    because that check fails only when no API key is set.

    Returns:
        Process exit code: 0 when every check that ran passed, 1 otherwise.
    """
    print("\n" + "🚀 ContentModerationEnv - Pre-Deployment Tests ".center(50, "="))
    print()

    config_ok = test_configuration()
    results = [
        ("Configuration", config_ok),
        ("Environment", test_environment_loading()),
    ]

    if config_ok:
        results.append(("API Connection", test_api_connection()))
    else:
        print("\n⚠️  Skipping API test (configuration incomplete)")

    # Summary
    print("\n" + "📊 Test Summary ".center(50, "="))
    for name, passed in results:
        status = "✅ PASS" if passed else "❌ FAIL"
        print(f"{status} - {name}")

    print("=" * 50)

    if all(passed for _, passed in results):
        print("\n🎉 All tests passed! Ready to deploy.")
        print("\nNext steps:")
        print("  1. Run: ./deploy.sh YOUR_USERNAME SPACE_NAME")
        print("  2. Or manually: git push to your HF Space")
        print("  3. Set secrets in HF Space Settings")
        return 0

    print("\n⚠️  Some tests failed. Please fix issues before deploying.")
    print("\nTroubleshooting:")
    print("  1. Set environment variables (see .env.example)")
    print("  2. Verify API key is valid")
    print("  3. Check network connectivity")
    return 1

if __name__ == "__main__":
    # main() returns the process exit status (0 when every check passed).
    raise SystemExit(main())