cryogenic22 committed on
Commit
c15551c
·
verified ·
1 Parent(s): 22c9c7b

Create ai_resilience.py

Browse files
Files changed (1) hide show
  1. ai_resilience.py +208 -0
ai_resilience.py ADDED
@@ -0,0 +1,208 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # File: ai_resilience.py
2
+ # Location: /ai_resilience.py
3
+ # Description: Robust AI API error handling and fallback mechanism
4
+
5
import logging
import os
from typing import List, Dict, Any, Callable

import openai
import anthropic
from tenacity import (
    Retrying,
    retry,
    stop_after_attempt,
    wait_exponential,
    retry_if_exception_type,
)
15
+
16
class AIResilientClient:
    """Resilient multi-provider AI text-generation client.

    Wraps the OpenAI and Anthropic chat APIs behind a single
    ``generate_content`` call.  Each individual provider request is
    retried with exponential backoff on transient API errors; when a
    model keeps failing, the client falls through an ordered list of
    fallback models before giving up.
    """

    def __init__(self,
                 openai_key: str = None,
                 anthropic_key: str = None,
                 max_retries: int = 3):
        """
        Initialize resilient AI client with multiple model support.

        Args:
            openai_key (str): OpenAI API key (provider disabled if None)
            anthropic_key (str): Anthropic API key (provider disabled if None)
            max_retries (int): Maximum retry attempts per provider call
        """
        # Configure logging.  NOTE(review): basicConfig mutates global
        # logging state; kept for backward compatibility (it is a no-op
        # when the root logger is already configured).
        logging.basicConfig(
            level=logging.INFO,
            format='%(asctime)s - %(levelname)s - %(message)s'
        )
        self.logger = logging.getLogger(__name__)

        # Initialize API clients; a provider with no key stays None and
        # is skipped by _generate_with_model.
        self.openai_client = openai.OpenAI(api_key=openai_key) if openai_key else None
        self.anthropic_client = anthropic.Anthropic(api_key=anthropic_key) if anthropic_key else None

        # Retry policy for a single provider call.  BUGFIX: the original
        # ``@retry`` decorator sat on generate_content(), whose inner
        # ``except Exception`` swallowed every retryable API error before
        # the decorator could see it — so no retry ever fired, and the
        # ``max_retries`` argument was ignored.  Retries now wrap each
        # individual provider call and honour ``max_retries``.
        self._retryer = Retrying(
            stop=stop_after_attempt(max_retries),
            wait=wait_exponential(multiplier=1, min=4, max=10),
            retry=retry_if_exception_type(
                (openai.APIError,
                 openai.APIConnectionError,
                 openai.RateLimitError,
                 anthropic.APIError)
            ),
            reraise=True,  # surface the original exception to the fallback loop
        )

        # Ordered fallback chain: tried first-to-last by generate_content().
        self.models = [
            {
                'provider': 'openai',
                'model': 'gpt-4-turbo',
                'temperature': 0.7,
                'max_tokens': 4096
            },
            {
                'provider': 'anthropic',
                'model': 'claude-3-opus-20240229',
                'temperature': 0.7,
                'max_tokens': 4096
            },
            {
                'provider': 'openai',
                'model': 'gpt-3.5-turbo',
                'temperature': 0.7,
                'max_tokens': 4096
            }
        ]

    def generate_content(
        self,
        prompt: str,
        context: Dict[str, Any] = None
    ) -> str:
        """
        Generate content with intelligent fallback and error handling.

        Args:
            prompt (str): Content generation prompt
            context (Dict, optional): Additional context for generation
                (currently unused; reserved for future prompt shaping)

        Returns:
            str: Generated content

        Raises:
            RuntimeError: If every configured model fails.
        """
        # Attempt generation with each configured model, in order.
        for model_config in self.models:
            try:
                return self._generate_with_model(model_config, prompt, context)
            except Exception as e:
                # Log and fall through to the next model in the chain.
                # Lazy %-style args avoid formatting when the level is off.
                self.logger.warning(
                    "Failed with %s %s: %s",
                    model_config['provider'], model_config['model'], e
                )

        # If all models fail, raise a critical error.
        error_msg = "All AI models failed. Unable to generate content."
        self.logger.critical(error_msg)
        raise RuntimeError(error_msg)

    def _generate_with_model(
        self,
        model_config: Dict[str, Any],
        prompt: str,
        context: Dict[str, Any] = None
    ) -> str:
        """
        Generate content with a specific model, retrying transient errors.

        Args:
            model_config (Dict): Model configuration
            prompt (str): Content generation prompt
            context (Dict, optional): Additional context (unused)

        Returns:
            str: Generated content

        Raises:
            ValueError: If no client is configured for the provider.
        """
        provider = model_config['provider']

        if provider == 'openai' and self.openai_client:
            return self._retryer(self._call_openai, model_config, prompt)

        if provider == 'anthropic' and self.anthropic_client:
            return self._retryer(self._call_anthropic, model_config, prompt)

        raise ValueError(f"No valid client for provider: {provider}")

    def _call_openai(self, model_config: Dict[str, Any], prompt: str) -> str:
        """Issue a single OpenAI chat-completion request (no retry logic)."""
        response = self.openai_client.chat.completions.create(
            model=model_config['model'],
            messages=[
                {"role": "system", "content": "You are a helpful assistant."},
                {"role": "user", "content": prompt}
            ],
            temperature=model_config.get('temperature', 0.7),
            max_tokens=model_config.get('max_tokens', 4096)
        )
        return response.choices[0].message.content

    def _call_anthropic(self, model_config: Dict[str, Any], prompt: str) -> str:
        """Issue a single Anthropic messages request (no retry logic)."""
        response = self.anthropic_client.messages.create(
            model=model_config['model'],
            max_tokens=model_config.get('max_tokens', 4096),
            temperature=model_config.get('temperature', 0.7),
            messages=[
                {"role": "user", "content": prompt}
            ]
        )
        return response.content[0].text

    def validate_api_keys(self) -> Dict[str, bool]:
        """
        Validate API keys by issuing a cheap list-models probe request.

        Returns:
            Dict[str, bool]: Validation status per provider; False when the
            provider is unconfigured or the probe request raised.
        """
        results = {
            'openai': False,
            'anthropic': False
        }

        # Test OpenAI — best-effort probe: any failure just leaves False.
        if self.openai_client:
            try:
                self.openai_client.models.list()
                results['openai'] = True
            except Exception:
                pass

        # Test Anthropic.  NOTE(review): assumes the installed anthropic
        # SDK exposes ``models.list()`` — TODO confirm minimum SDK version.
        if self.anthropic_client:
            try:
                self.anthropic_client.models.list()
                results['anthropic'] = True
            except Exception:
                pass

        return results
174
+
175
# Comprehensive Test Suite
def run_resilience_tests():
    """
    Run smoke tests for the AI resilience module.

    Reads API keys from the OPENAI_API_KEY / ANTHROPIC_API_KEY
    environment variables; providers whose key is unset are simply
    skipped by the client.
    """
    # BUGFIX: the original called os.getenv without ``os`` ever being
    # imported, raising NameError at runtime; ``os`` is now imported at
    # the top of the file.
    client = AIResilientClient(
        openai_key=os.getenv('OPENAI_API_KEY'),
        anthropic_key=os.getenv('ANTHROPIC_API_KEY')
    )

    # Test API Key Validation
    api_status = client.validate_api_keys()
    print("API Key Validation:", api_status)

    # Test Content Generation end-to-end with a couple of prompts.
    test_prompts = [
        "Write a short paragraph about artificial intelligence.",
        "Explain the concept of quantum computing in simple terms."
    ]

    for prompt in test_prompts:
        try:
            generated_content = client.generate_content(prompt)
            print(f"\nGenerated Content for Prompt: {prompt}")
            print(generated_content)
        except Exception as e:
            # Best-effort smoke test: report the failure and continue.
            print(f"Content generation failed: {e}")

    print("\nResilience tests completed.")
205
+
206
# Entry point: run the smoke tests only when executed directly as a
# script, not when this module is imported.
if __name__ == "__main__":
    run_resilience_tests()