import asyncio
import json
import logging
from typing import TypeVar, Type, Optional, Callable

from pydantic import BaseModel
from langchain_mistralai.chat_models import ChatMistralAI
from langchain.schema import SystemMessage, HumanMessage
from langchain.schema.messages import BaseMessage

T = TypeVar('T', bound=BaseModel)

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


class MistralAPIError(Exception):
    """Base class for Mistral API errors."""
    pass


class MistralRateLimitError(MistralAPIError):
    """Raised when hitting rate limits."""
    pass


class MistralParsingError(MistralAPIError):
    """Raised when response parsing fails."""
    pass


class MistralValidationError(MistralAPIError):
    """Raised when response validation fails."""
    pass
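
# Illustrative sketch: downstream code can branch on this hierarchy
# (assumes a `client = MistralClient(...)` instance, defined below):
#
#     try:
#         result = await client.generate(messages)
#     except MistralAPIError:
#         ...  # covers rate-limit, parsing, validation, and auth failures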


class MistralClient:
    def __init__(self, api_key: str, model_name: str = "mistral-small-latest", max_tokens: int = 1000):
        logger.info(f"Initializing MistralClient with model: {model_name}, max_tokens: {max_tokens}")
        self.model = ChatMistralAI(
            mistral_api_key=api_key,
            model=model_name,
            max_tokens=max_tokens
        )
        # Secondary instance kept separate for response-fixing calls
        # (currently configured identically to the primary model).
        self.fixing_model = ChatMistralAI(
            mistral_api_key=api_key,
            model=model_name,
            max_tokens=max_tokens
        )

        # Rate-limiting and retry configuration.
        self.last_call_time = 0.0   # event-loop timestamp of the last API call
        self.min_delay = 1          # minimum seconds between consecutive calls
        self.max_retries = 5
        self.backoff_factor = 2     # exponential backoff base
        self.max_backoff = 30       # cap on backoff delay, in seconds

    async def _wait_for_rate_limit(self):
        """Wait as long as needed to respect the minimum delay between API calls."""
        current_time = asyncio.get_event_loop().time()
        time_since_last_call = current_time - self.last_call_time

        if time_since_last_call < self.min_delay:
            delay = self.min_delay - time_since_last_call
            logger.debug(f"Rate limit: waiting for {delay:.2f} seconds")
            await asyncio.sleep(delay)

        self.last_call_time = asyncio.get_event_loop().time()

    async def _handle_api_error(self, error: Exception, retry_count: int) -> float:
        """Handle an API error: raise if unrecoverable, else return the wait time before retry."""
        wait_time = min(self.backoff_factor ** retry_count, self.max_backoff)

        if "rate limit" in str(error).lower():
            # Rate limits are transient: log and let the caller retry after backoff.
            logger.warning(f"Rate limit hit, waiting {wait_time}s before retry")
        elif "403" in str(error):
            # Authentication problems will not resolve on retry.
            logger.error("Authentication error - invalid API key or quota exceeded")
            raise MistralAPIError("Authentication failed")

        return wait_time

    async def _generate_with_retry(
        self,
        messages: list[BaseMessage],
        response_model: Optional[Type[T]] = None,
        custom_parser: Optional[Callable[[str], T]] = None,
        error_feedback: Optional[str] = None
    ) -> T | str:
        retry_count = 0
        last_error = None

        while retry_count < self.max_retries:
            try:
                logger.info(f"Attempt {retry_count + 1}/{self.max_retries}")

                current_messages = messages.copy()
                if error_feedback and retry_count > 0:
                    if isinstance(last_error, MistralParsingError):
                        current_messages.append(HumanMessage(content="Please ensure your response is in valid JSON format."))
                    elif isinstance(last_error, MistralValidationError):
                        current_messages.append(HumanMessage(content=f"Previous error: {error_feedback}. Please try again."))

                await self._wait_for_rate_limit()
                try:
                    response = await self.model.ainvoke(current_messages)
                    content = response.content
                    logger.debug(f"Raw response: {content[:100]}...")
                except Exception as api_error:
                    wait_time = await self._handle_api_error(api_error, retry_count)
                    last_error = api_error
                    retry_count += 1
                    if retry_count < self.max_retries:
                        await asyncio.sleep(wait_time)
                        continue
                    raise

                # No structured output requested: return the raw text.
                if not response_model and not custom_parser:
                    return content

                # Parse the response, preferring the caller-supplied parser.
                try:
                    if custom_parser:
                        return custom_parser(content)

                    data = json.loads(content)
                    return response_model(**data)
                except json.JSONDecodeError as e:
                    last_error = MistralParsingError(f"Invalid JSON format: {str(e)}")
                    logger.error(f"JSON parsing error: {str(e)}")
                    raise last_error
                except Exception as e:
                    last_error = MistralValidationError(str(e))
                    logger.error(f"Validation error: {str(e)}")
                    raise last_error

            except (MistralParsingError, MistralValidationError) as e:
                logger.error(f"Error on attempt {retry_count + 1}/{self.max_retries}: {str(e)}")
                last_error = e
                # Feed the error back so the next attempt can correct itself.
                error_feedback = str(e)
                retry_count += 1
                if retry_count < self.max_retries:
                    wait_time = min(self.backoff_factor ** retry_count, self.max_backoff)
                    logger.info(f"Waiting {wait_time} seconds before retry...")
                    await asyncio.sleep(wait_time)
                    continue

        logger.error(f"Failed after {self.max_retries} attempts. Last error: {str(last_error)}")
        raise MistralAPIError(f"Failed after {self.max_retries} attempts. Last error: {str(last_error)}")

    async def generate(self, messages: list[BaseMessage], response_model: Optional[Type[T]] = None, custom_parser: Optional[Callable[[str], T]] = None) -> T | str:
        """Generate a response from a list of messages, with optional structured parsing."""
        return await self._generate_with_retry(messages, response_model, custom_parser)
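
    # Illustrative sketch of the custom_parser hook (the PanelList model and
    # parse_panels function below are hypothetical, not part of this module):
    #
    #     class PanelList(BaseModel):
    #         panels: list[str]
    #
    #     def parse_panels(text: str) -> PanelList:
    #         return PanelList.model_validate_json(text)
    #
    #     result = await client.generate(messages, custom_parser=parse_panels)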

    async def transform_prompt(self, story_text: str, art_prompt: str) -> str:
        """Transform a story text into an artistic prompt."""
        messages = [
            SystemMessage(content=art_prompt),
            HumanMessage(content=f"Transform this story text into a comic panel description:\n{story_text}")
        ]
        try:
            return await self._generate_with_retry(messages)
        except Exception as e:
            # Fall back to the raw story text rather than failing the pipeline.
            logger.error(f"Error transforming prompt: {str(e)}")
            return story_text

    async def generate_text(self, messages: list[BaseMessage]) -> str:
        """
        Generate a plain text response with no JSON structure.
        Useful for narrative or descriptive text generation.

        Args:
            messages: List of messages for the model

        Returns:
            str: The generated text
        """
        retry_count = 0
        last_error = None

        while retry_count < self.max_retries:
            try:
                logger.info(f"Attempt {retry_count + 1}/{self.max_retries}")

                await self._wait_for_rate_limit()
                response = await self.model.ainvoke(messages)
                return response.content.strip()

            except Exception as e:
                logger.error(f"Error on attempt {retry_count + 1}/{self.max_retries}: {str(e)}")
                last_error = e
                retry_count += 1
                if retry_count < self.max_retries:
                    wait_time = 2 * retry_count
                    logger.info(f"Waiting {wait_time} seconds before retry...")
                    await asyncio.sleep(wait_time)

        logger.error(f"Failed after {self.max_retries} attempts. Last error: {str(last_error)}")
        raise MistralAPIError(f"Failed after {self.max_retries} attempts. Last error: {str(last_error)}")

    async def check_health(self) -> bool:
        """
        Check the availability of the Mistral service with a single call, no retries.

        Returns:
            bool: True if the service is available, False otherwise
        """
        try:
            await self.model.ainvoke([SystemMessage(content="Hi")])
            return True
        except Exception as e:
            logger.error(f"Health check failed: {str(e)}")
            return False
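

# Minimal usage sketch (assumes MISTRAL_API_KEY is set in the environment;
# the Story model is hypothetical and only serves as an example schema).
if __name__ == "__main__":
    import os

    class Story(BaseModel):
        title: str
        panels: list[str]

    async def _demo() -> None:
        client = MistralClient(api_key=os.environ["MISTRAL_API_KEY"])
        if not await client.check_health():
            raise SystemExit("Mistral service unavailable")

        # Plain text generation.
        text = await client.generate_text([HumanMessage(content="Say hello in one sentence.")])
        print(text)

        # Structured generation validated against a pydantic model.
        story = await client.generate(
            [
                SystemMessage(content='Respond only with JSON: {"title": str, "panels": [str]}.'),
                HumanMessage(content="Outline a 3-panel comic about a robot learning to paint."),
            ],
            response_model=Story,
        )
        print(story)

    asyncio.run(_demo())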