Spaces:
Runtime error
Runtime error
| import json | |
| import logging | |
| from openai import OpenAI | |
| from typing import Dict, Any, Optional | |
| import gradio as gr | |
| from prompts import PROMPT_ANALYZER_TEMPLATE | |
| import time | |
logger = logging.getLogger(__name__)

# Groq-hosted models to try in order when a request or JSON parse fails.
FALLBACK_MODELS = [
    "mixtral-8x7b-32768",
    "llama-3.1-70b-versatile",
    "llama-3.1-8b-instant",
    "llama3-70b-8192",
    "llama3-8b-8192"
]


class ModelManager:
    """Cycles through FALLBACK_MODELS so callers can retry on another model."""

    def __init__(self):
        self.current_model_index = 0  # position in FALLBACK_MODELS
        self.max_retries = len(FALLBACK_MODELS)  # one attempt per model

    @property
    def current_model(self) -> str:
        """Name of the model currently selected.

        Declared as a property because every call site accesses it as an
        attribute (e.g. ``model=manager.current_model``); as a plain method
        those sites would receive a bound-method object instead of the
        model-name string.
        """
        return FALLBACK_MODELS[self.current_model_index]

    def next_model(self) -> str:
        """Advance to the next fallback model (wrapping around) and return its name."""
        self.current_model_index = (self.current_model_index + 1) % len(FALLBACK_MODELS)
        logger.info(f"Switching to model: {self.current_model}")
        return self.current_model
class PromptEnhancementAPI:
    """Wrapper around an OpenAI-compatible chat endpoint (Groq by default)
    that retries across FALLBACK_MODELS until one returns a parseable JSON
    object, then returns that object as a dict.
    """

    def __init__(self, api_key: str, base_url: Optional[str] = None):
        # Defaults to Groq's OpenAI-compatible endpoint when no base_url is given.
        self.client = OpenAI(
            api_key=api_key,
            base_url=base_url or "https://api.groq.com/openai/v1"
        )
        self.model_manager = ModelManager()

    def _try_parse_json(self, content: str, retries: int = 0) -> Dict[str, Any]:
        """Parse *content* as a JSON object (dict).

        On failure, advances the model manager to the next fallback model
        (unless this was the last allowed attempt) and re-raises, so the
        caller's retry loop in generate_enhancement can try again with the
        newly selected model.

        Args:
            content: raw text of the model response.
            retries: the caller's current attempt count, used to decide
                whether another model switch is still allowed.

        Raises:
            json.JSONDecodeError: if *content* is not valid JSON.
            ValueError: if it is valid JSON but not an object.
        """
        try:
            result = json.loads(content.strip().lstrip('\n'))
            if not isinstance(result, dict):
                raise ValueError("Response is not a valid JSON object")
            return result
        except (json.JSONDecodeError, ValueError) as e:
            if retries < self.model_manager.max_retries - 1:
                # Switch models *before* re-raising; the outer retry loop
                # will then issue the next request against the new model.
                logger.warning(f"JSON parsing failed with model {self.model_manager.current_model}. Switching models...")
                self.model_manager.next_model()
                raise e
            # Last attempt exhausted: propagate so the caller can fall back
            # to an error response.
            logger.error(f"JSON parsing failed with all models: {str(e)}")
            raise

    def generate_enhancement(self, system_prompt: str, user_prompt: str, user_directive: str = "", state: Optional[Dict] = None) -> Dict[str, Any]:
        """Request a JSON enhancement result, retrying across fallback models.

        Builds a chat message list from the prompts (plus optional directive
        and prior *state*, the latter replayed as an assistant message),
        then loops over the fallback models until one yields a valid JSON
        object.

        Returns:
            The parsed JSON dict, or — if every model fails to produce valid
            JSON — whatever create_error_response(user_prompt, user_directive)
            returns.  NOTE(review): create_error_response is not defined in
            this file's visible portion; confirm it is imported/defined
            elsewhere.

        Raises:
            gr.Error: for API failures other than rate limiting.
        """
        retries = 0
        last_error = None
        while retries < self.model_manager.max_retries:
            try:
                messages = [
                    {"role": "system", "content": system_prompt},
                    {"role": "user", "content": user_prompt}
                ]
                if user_directive:
                    messages.append({"role": "user", "content": f"User directive: {user_directive}"})
                if state:
                    # Replay prior result as assistant context for iterative refinement.
                    messages.append({
                        "role": "assistant",
                        "content": json.dumps(state)
                    })
                response = self.client.chat.completions.create(
                    model=self.model_manager.current_model,
                    messages=messages,
                    temperature=0.7,
                    max_tokens=4000,
                    response_format={"type": "json_object"}
                )
                # _try_parse_json switches to the next model before raising,
                # so the next loop iteration uses a fresh model.
                result = self._try_parse_json(response.choices[0].message.content, retries)
                return result
            except (json.JSONDecodeError, ValueError) as e:
                last_error = e
                retries += 1
                if retries < self.model_manager.max_retries:
                    logger.warning(f"Attempt {retries} failed. Switching models and retrying...")
                    time.sleep(1)  # Brief pause before retry
                    continue
                break
            except Exception as e:
                logger.error(f"API error: {str(e)}")
                # Rate limits get the same model-rotation treatment; any
                # other API error is surfaced to the Gradio UI immediately.
                if "rate limit" in str(e).lower():
                    if retries < self.model_manager.max_retries - 1:
                        self.model_manager.next_model()
                        retries += 1
                        time.sleep(1)
                        continue
                raise gr.Error(f"API request failed: {str(e)}")
        logger.error(f"All models failed to generate valid JSON: {str(last_error)}")
        return create_error_response(user_prompt, user_directive)
class PromptEnhancementSystem:
    """Stateful facade over PromptEnhancementAPI that tracks an enhancement
    session: the most recent result and the full ordered history of results.
    """

    def __init__(self, api_key: str, base_url: Optional[str] = None):
        self.api = PromptEnhancementAPI(api_key, base_url)
        self.current_state = None  # most recent enhancement result (dict), or None before start_session
        self.history = []          # every result produced this session, oldest first

    def _enhance(self, prompt: str, user_directive: str, state: Optional[Dict]) -> Dict[str, Any]:
        # Shared request path for start_session/apply_enhancement: wrap the
        # prompt in the analyzer template and call the API.
        formatted_system_prompt = PROMPT_ANALYZER_TEMPLATE.format(
            input_prompt=prompt,
            user_directive=user_directive
        )
        return self.api.generate_enhancement(
            system_prompt=formatted_system_prompt,
            user_prompt=prompt,
            user_directive=user_directive,
            state=state
        )

    def start_session(self, prompt: str, user_directive: str = "") -> Dict[str, Any]:
        """Start a fresh session for *prompt*, resetting state and history.

        Returns the first enhancement result from the API.
        """
        result = self._enhance(prompt, user_directive, state=None)
        self.current_state = result
        self.history = [result]
        return result

    def apply_enhancement(self, choice: str, user_directive: str = "") -> Dict[str, Any]:
        """Enhance *choice*, passing the previous result as conversational state.

        Appends the new result to history and returns it.
        """
        result = self._enhance(choice, user_directive, state=self.current_state)
        self.current_state = result
        self.history.append(result)
        return result