"""
LLM service for integration with Groq and other providers.

Handles calls to language models for summarization and analysis.
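
Quick start (illustrative sketch; assumes GROQ_API_KEY is available via
config.settings):

    service = LLMService()
    answer = await service.generate_completion(
        "Your prompt here",
        system_prompt="Optional system prompt",
    )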
"""

import asyncio
import time
import traceback
from typing import Any, Dict, List, Optional

import aiohttp

from config.settings import api_config
from src.core.logging import setup_logger


class LLMError(Exception):
    """Exception raised for LLM errors."""
    pass


class LLMRateLimitError(LLMError):
    """Exception raised when the API rate limit is exceeded."""
    pass


class LLMService:
    """
    Service for calling language models.

    Features:
    - Groq API support
    - Rate-limit handling
    - Automatic retry with backoff
    - Optional streaming
    - Response validation
    """

    def __init__(self):
        self.config = api_config
        self.logger = setup_logger("llm_service")

        # Groq connection settings
        self.groq_api_key = self.config.GROQ_API_KEY
        self.groq_base_url = "https://api.groq.com/openai/v1"
        self.default_model = getattr(self.config, 'GROQ_MODEL', "llama-3.1-8b-instant")

        # Conservative client-side limits (requests/tokens per minute);
        # only the request count is enforced in _check_rate_limits()
        self.rate_limit_requests = 30
        self.rate_limit_tokens = 6000
        self.request_timestamps = []

        # Default generation parameters, overridable per call
        self.default_params = {
            "temperature": 0.3,
            "max_tokens": 2000,
            "top_p": 0.9,
            "frequency_penalty": 0.1,
            "presence_penalty": 0.1
        }

        # HTTP headers shared by all requests
        self.headers = {
            "Authorization": f"Bearer {self.groq_api_key}",
            "Content-Type": "application/json"
        }

    async def generate_completion(
        self,
        prompt: str,
        system_prompt: Optional[str] = None,
        model: Optional[str] = None,
        **kwargs
    ) -> str:
        """
        Generates a text completion.

        Args:
            prompt: User prompt
            system_prompt: Optional system prompt
            model: Model to use (default: from config)
            **kwargs: Additional parameters for the API

        Returns:
            Response generated by the model

        Raises:
            LLMError: On an API error
            LLMRateLimitError: When the rate limit is exceeded
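
        Example (illustrative sketch; assumes a configured GROQ_API_KEY):
            service = LLMService()
            summary = await service.generate_completion(
                "Summarize this paragraph...",
                system_prompt="You are a concise summarizer.",
                temperature=0.2,
            )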
        """
        # Build the chat message list, optional system prompt first
        messages = []
        if system_prompt:
            messages.append({"role": "system", "content": system_prompt})
        messages.append({"role": "user", "content": prompt})

        # Merge the default generation parameters with caller overrides
        params = {**self.default_params, **kwargs}
        payload = {
            "model": model or self.default_model,
            "messages": messages,
            **params
        }

        # Respect the client-side rate limit before issuing the request
        await self._check_rate_limits()

        return await self._make_api_call(payload)

    async def generate_batch_completions(
        self,
        prompts: List[str],
        system_prompt: Optional[str] = None,
        model: Optional[str] = None,
        max_concurrent: int = 3,
        **kwargs
    ) -> List[str]:
        """
        Generates several completions in parallel.

        Args:
            prompts: List of prompts
            system_prompt: Optional system prompt
            model: Model to use
            max_concurrent: Maximum number of simultaneous requests
            **kwargs: Additional parameters

        Returns:
            List of responses in the same order as the prompts
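
        Example (illustrative sketch):
            answers = await service.generate_batch_completions(
                ["Prompt A", "Prompt B"],
                system_prompt="Answer briefly.",
                max_concurrent=2,
            )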
        """
        self.logger.info(f"Batch generation of {len(prompts)} completions")

        semaphore = asyncio.Semaphore(max_concurrent)

        async def generate_single(prompt: str, index: int) -> tuple:
            async with semaphore:
                try:
                    # Stagger start times slightly to soften rate-limit pressure
                    await asyncio.sleep(index * 0.5)

                    result = await self.generate_completion(
                        prompt, system_prompt, model, **kwargs
                    )
                    return index, result
                except Exception as e:
                    self.logger.error(f"Completion {index} failed: {e}")
                    return index, f"ERROR: {str(e)}"

        tasks = [generate_single(prompt, i) for i, prompt in enumerate(prompts)]
        results = await asyncio.gather(*tasks, return_exceptions=True)

        # Reassemble responses in prompt order; generate_single already catches
        # exceptions, so anything that is not a tuple can only come from a
        # cancelled or otherwise abnormal task and must not shift the ordering.
        ordered_results = [""] * len(prompts)
        for result in results:
            if isinstance(result, tuple):
                index, content = result
                ordered_results[index] = content
            else:
                self.logger.error(f"Unexpected batch exception: {result}")

        success_count = sum(1 for r in ordered_results if not r.startswith("ERROR"))
        self.logger.info(f"Batch finished: {success_count}/{len(prompts)} succeeded")

        return ordered_results

    async def _make_api_call(self, payload: Dict[str, Any], max_retries: int = 3) -> str:
        """Performs the API call with automatic retry."""
        url = f"{self.groq_base_url}/chat/completions"

        for attempt in range(max_retries + 1):
            try:
                async with aiohttp.ClientSession(timeout=aiohttp.ClientTimeout(total=60)) as session:
                    async with session.post(url, json=payload, headers=self.headers) as response:

                        # Record the request for client-side rate limiting
                        self.request_timestamps.append(time.time())

                        if response.status == 200:
                            data = await response.json()
                            content = data["choices"][0]["message"]["content"]

                            if not content or content.strip() == "":
                                raise LLMError("Empty response from the model")

                            return content.strip()

                        elif response.status == 429:
                            # Honor the server-provided Retry-After header when
                            # present (it may carry a fractional value)
                            retry_after = float(response.headers.get("retry-after", 60))
                            self.logger.warning(f"Rate limit hit, waiting {retry_after}s")

                            if attempt < max_retries:
                                await asyncio.sleep(retry_after)
                                continue
                            else:
                                raise LLMRateLimitError("API rate limit exceeded")

                        else:
                            error_text = await response.text()
                            error_msg = f"API error {response.status}: {error_text}"

                            if attempt < max_retries:
                                self.logger.warning(f"{error_msg} - Attempt {attempt + 1}/{max_retries}")
                                await asyncio.sleep(2 ** attempt)  # exponential backoff
                                continue
                            else:
                                raise LLMError(error_msg)

            except asyncio.TimeoutError:
                if attempt < max_retries:
                    self.logger.warning(f"API timeout - Attempt {attempt + 1}/{max_retries}")
                    await asyncio.sleep(2 ** attempt)
                    continue
                else:
                    raise LLMError("API timeout after several attempts")

            except aiohttp.ClientError as e:
                # Catch only transport errors here, so that the LLMError and
                # LLMRateLimitError raised above propagate instead of being
                # re-wrapped as connection errors
                if attempt < max_retries:
                    self.logger.warning(f"Network error: {e} - Attempt {attempt + 1}/{max_retries}")
                    await asyncio.sleep(2 ** attempt)
                    continue
                else:
                    raise LLMError(f"Connection error: {str(e)}")

        raise LLMError("All attempts failed")

    async def _check_rate_limits(self):
        """Checks and enforces client-side rate limits."""
        current_time = time.time()

        # Keep only the requests made within the last minute
        self.request_timestamps = [
            ts for ts in self.request_timestamps
            if current_time - ts < 60
        ]

        # If the per-minute budget is spent, wait until the oldest request expires
        if len(self.request_timestamps) >= self.rate_limit_requests:
            oldest_request = min(self.request_timestamps)
            wait_time = 60 - (current_time - oldest_request)

            if wait_time > 0:
                self.logger.info(f"Rate limit: waiting {wait_time:.1f}s")
                await asyncio.sleep(wait_time)

    def estimate_tokens(self, text: str) -> int:
        """Estimates the number of tokens in a text."""
        # Rough heuristic: on average about 4 characters per token
        return len(text) // 4

    def validate_input_length(self, text: str, max_tokens: int = 6000) -> bool:
        """Validates that the text does not exceed the token limit."""
        estimated_tokens = self.estimate_tokens(text)
        return estimated_tokens <= max_tokens

    def truncate_text(self, text: str, max_tokens: int = 6000) -> str:
        """Truncates a text so that it fits within the token limit."""
        estimated_tokens = self.estimate_tokens(text)

        if estimated_tokens <= max_tokens:
            return text

        # Character budget with a 10% safety margin
        ratio = max_tokens / estimated_tokens
        target_length = int(len(text) * ratio * 0.9)

        # Cut on sentence boundaries to keep the result readable
        sentences = text.split('. ')
        truncated = ""

        for sentence in sentences:
            if len(truncated) + len(sentence) + 2 <= target_length:
                truncated += sentence + ". "
            else:
                break

        # Fall back to a hard cut if even the first sentence exceeds the budget
        if not truncated:
            truncated = text[:target_length]

        self.logger.info(f"Text truncated: {len(text)} → {len(truncated)} characters")
        return truncated.strip()

    async def test_connection(self) -> bool:
        """Tests the connection to the API."""
        try:
            result = await self.generate_completion(
                "Connection test. Just answer 'OK'.",
                system_prompt="You are a test assistant."
            )

            if "ok" in result.lower():
                self.logger.info("LLM connection test succeeded")
                return True
            else:
                self.logger.warning(f"Unexpected connection test response: {result}")
                return False

        except Exception as e:
            self.logger.error(f"LLM connection test failed: {e}")
            return False


class LLMManager:
    """
    Manager for LLM services with support for multiple backends.
    """

    def __init__(self):
        self.logger = setup_logger("llm_manager")
        self.primary_service = LLMService()
        self.services = {
            "groq": self.primary_service
        }

    async def get_completion(
        self,
        prompt: str,
        system_prompt: Optional[str] = None,
        service: str = "groq",
        **kwargs
    ) -> str:
        """
        Gets a completion using the specified service.

        Args:
            prompt: User prompt
            system_prompt: System prompt
            service: LLM service to use
            **kwargs: Additional parameters

        Returns:
            Model response
        """
        if service not in self.services:
            raise ValueError(f"Unknown LLM service: {service}")

        llm_service = self.services[service]
        return await llm_service.generate_completion(prompt, system_prompt, **kwargs)

    async def get_batch_completions(
        self,
        prompts: List[str],
        system_prompt: Optional[str] = None,
        service: str = "groq",
        **kwargs
    ) -> List[str]:
        """Gets completions in batch."""
        if service not in self.services:
            raise ValueError(f"Unknown LLM service: {service}")

        llm_service = self.services[service]
        return await llm_service.generate_batch_completions(
            prompts, system_prompt, **kwargs
        )

    async def test_all_services(self) -> Dict[str, bool]:
        """Tests every available LLM service."""
        results = {}

        for name, service in self.services.items():
            try:
                results[name] = await service.test_connection()
            except Exception as e:
                self.logger.error(f"Service {name} test failed: {e}")
                results[name] = False

        return results


async def example_usage():
    """Example usage of the LLM service."""

    print("=== Connection test ===")
    llm_service = LLMService()

    connection_ok = await llm_service.test_connection()
    print(f"LLM connection: {'✓ OK' if connection_ok else '✗ Failed'}")

    if not connection_ok:
        print("Cannot continue without a connection")
        return

    print("\n=== Simple generation ===")
    try:
        response = await llm_service.generate_completion(
            prompt="Explain in 2 sentences what artificial intelligence is.",
            system_prompt="You are an AI expert who explains things simply."
        )
        print(f"Response: {response}")
    except Exception as e:
        print(f"Error: {e}")

    print("\n=== Generation with custom parameters ===")
    try:
        response = await llm_service.generate_completion(
            prompt="Write a haiku about technology.",
            system_prompt="You are a poet specialized in haikus.",
            temperature=0.8,
            max_tokens=100
        )
        print(f"Haiku: {response}")
    except Exception as e:
        print(f"Error: {e}")

    print("\n=== Batch generation ===")
    prompts = [
        "What is Python?",
        "What is JavaScript?",
        "What is Rust?"
    ]

    try:
        responses = await llm_service.generate_batch_completions(
            prompts=prompts,
            system_prompt="Answer in one short sentence.",
            max_concurrent=2
        )

        for i, (prompt, response) in enumerate(zip(prompts, responses)):
            print(f"{i+1}. {prompt}")
            print(f"   → {response}\n")
    except Exception as e:
        print(f"Batch error: {e}")

    print("\n=== Utility helpers ===")
    long_text = "This is a very long example sentence. " * 1000
    print(f"Original text: {len(long_text)} characters")
    print(f"Estimated tokens: {llm_service.estimate_tokens(long_text)}")

    is_valid = llm_service.validate_input_length(long_text, max_tokens=7000)
    print(f"Text valid (7000 tokens max): {is_valid}")

    if not is_valid:
        truncated = llm_service.truncate_text(long_text, max_tokens=7000)
        print(f"Truncated text: {len(truncated)} characters")
        print(f"Content: {truncated[:200]}...")


async def example_manager_usage():
    """Example usage of the LLM manager."""

    print("\n=== LLM manager test ===")

    manager = LLMManager()

    service_status = await manager.test_all_services()
    print("Service status:")
    for service, status in service_status.items():
        print(f"  {service}: {'✓' if status else '✗'}")

    try:
        response = await manager.get_completion(
            prompt="Hi! How are you?",
            system_prompt="You are a friendly assistant.",
            service="groq"
        )
        print(f"\nManager response: {response}")
    except Exception as e:
        print(f"Manager error: {e}")


async def main():
    """Main test function."""
    try:
        await example_usage()
        await example_manager_usage()
    except KeyboardInterrupt:
        print("\n\nTest interrupted by the user")
    except Exception as e:
        print(f"\nUnexpected error: {e}")
        traceback.print_exc()


if __name__ == "__main__":
    print("🚀 Starting the LLM service test...")
    asyncio.run(main())