testpush / src /models /state_models.py
Bachir00's picture
source code
8a848a5
"""
Modèles d'état pour l'orchestration LangGraph.
Définit l'état global du système et les états des agents.
"""
from typing import List, Optional, Dict, Any, Union
from datetime import datetime
from pydantic import BaseModel, Field
from enum import Enum
from .research_models import ResearchQuery, ResearchOutput
from .document_models import SummarizationOutput
from .report_models import ReportOutput
class AgentType(str, Enum):
"""Types d'agents dans le système."""
RESEARCHER = "researcher"
CONTENT_EXTRACTOR = "content_extractor"
READER = "reader"
WRITER = "writer"
class AgentStatus(str, Enum):
"""Statuts possibles d'un agent."""
IDLE = "idle"
WORKING = "working"
COMPLETED = "completed"
ERROR = "error"
TIMEOUT = "timeout"
class ProcessingStep(str, Enum):
"""Étapes du processus de recherche."""
INIT = "init"
RESEARCH = "research"
READING = "reading"
WRITING = "writing"
COMPLETED = "completed"
ERROR = "error"
class AgentState(BaseModel):
"""
État individuel d'un agent.
"""
agent_type: AgentType = Field(..., description="Type de l'agent")
status: AgentStatus = Field(default=AgentStatus.IDLE, description="Statut actuel")
# Informations de timing
start_time: Optional[datetime] = Field(default=None, description="Heure de début d'exécution")
end_time: Optional[datetime] = Field(default=None, description="Heure de fin d'exécution")
duration: Optional[float] = Field(default=None, description="Durée d'exécution en secondes")
# Gestion des erreurs
error_message: Optional[str] = Field(default=None, description="Message d'erreur si applicable")
retry_count: int = Field(default=0, ge=0, description="Nombre de tentatives")
max_retries: int = Field(default=3, ge=0, description="Nombre maximum de tentatives")
# Métadonnées spécifiques à l'agent
metadata: Dict[str, Any] = Field(default_factory=dict, description="Données spécifiques à l'agent")
def start_execution(self):
"""Marque le début de l'exécution."""
self.status = AgentStatus.WORKING
self.start_time = datetime.now()
self.end_time = None
def complete_execution(self):
"""Marque la fin réussie de l'exécution."""
self.status = AgentStatus.COMPLETED
self.end_time = datetime.now()
if self.start_time:
self.duration = (self.end_time - self.start_time).total_seconds()
def mark_error(self, error_message: str):
"""Marque l'agent en erreur."""
self.status = AgentStatus.ERROR
self.error_message = error_message
self.end_time = datetime.now()
if self.start_time:
self.duration = (self.end_time - self.start_time).total_seconds()
class Config:
json_schema_extra = {
"example": {
"agent_type": "researcher",
"status": "completed",
"start_time": "2024-01-15T10:00:00Z",
"end_time": "2024-01-15T10:02:30Z",
"duration": 150.0,
"retry_count": 0,
"metadata": {"search_engine": "tavily"}
}
}
class GraphState(BaseModel):
"""
État global du graph LangGraph.
Contient toutes les données partagées entre les agents.
"""
# Identification de la session
session_id: str = Field(..., description="Identifiant unique de la session")
current_step: ProcessingStep = Field(default=ProcessingStep.INIT, description="Étape actuelle du processus")
# Requête initiale
original_query: Optional[ResearchQuery] = Field(default=None, description="Requête de recherche originale")
# États des agents
agents: Dict[AgentType, AgentState] = Field(
default_factory=lambda: {
AgentType.RESEARCHER: AgentState(agent_type=AgentType.RESEARCHER),
AgentType.READER: AgentState(agent_type=AgentType.READER),
AgentType.WRITER: AgentState(agent_type=AgentType.WRITER)
},
description="État de chaque agent"
)
# Données partagées entre agents
research_output: Optional[ResearchOutput] = Field(default=None, description="Résultats de recherche")
summarization_output: Optional[SummarizationOutput] = Field(default=None, description="Résultats de synthèse")
report_output: Optional[ReportOutput] = Field(default=None, description="Rapport final")
# Métadonnées globales
start_time: datetime = Field(default_factory=datetime.now, description="Heure de début du processus")
end_time: Optional[datetime] = Field(default=None, description="Heure de fin du processus")
total_duration: Optional[float] = Field(default=None, description="Durée totale en secondes")
# Configuration et paramètres
config: Dict[str, Any] = Field(default_factory=dict, description="Configuration du processus")
user_preferences: Dict[str, Any] = Field(default_factory=dict, description="Préférences utilisateur")
# Gestion des erreurs globales
global_errors: List[str] = Field(default_factory=list, description="Erreurs globales du processus")
is_successful: bool = Field(default=False, description="Indique si le processus s'est terminé avec succès")
# Informations de débogage
debug_info: Dict[str, Any] = Field(default_factory=dict, description="Informations de débogage")
def get_current_agent(self) -> Optional[AgentType]:
"""Retourne l'agent actuellement en cours d'exécution."""
for agent_type, agent_state in self.agents.items():
if agent_state.status == AgentStatus.WORKING:
return agent_type
return None
def is_agent_completed(self, agent_type: AgentType) -> bool:
"""Vérifie si un agent a terminé son exécution."""
return self.agents[agent_type].status == AgentStatus.COMPLETED
def all_agents_completed(self) -> bool:
"""Vérifie si tous les agents ont terminé."""
return all(
agent.status == AgentStatus.COMPLETED
for agent in self.agents.values()
)
def has_errors(self) -> bool:
"""Vérifie s'il y a des erreurs dans le processus."""
return (
len(self.global_errors) > 0 or
any(agent.status == AgentStatus.ERROR for agent in self.agents.values())
)
def complete_process(self):
"""Marque le processus comme terminé."""
self.end_time = datetime.now()
self.total_duration = (self.end_time - self.start_time).total_seconds()
self.current_step = ProcessingStep.COMPLETED
self.is_successful = not self.has_errors()
def add_global_error(self, error_message: str):
"""Ajoute une erreur globale."""
self.global_errors.append(error_message)
self.current_step = ProcessingStep.ERROR
class Config:
json_schema_extra = {
"example": {
"session_id": "session_123",
"current_step": "research",
"original_query": {
"topic": "impact de l'IA sur l'emploi"
},
"start_time": "2024-01-15T10:00:00Z",
"is_successful": False,
"global_errors": []
}
}
class WorkflowEvent(BaseModel):
"""
Événement dans le workflow LangGraph.
"""
event_id: str = Field(..., description="Identifiant unique de l'événement")
event_type: str = Field(..., description="Type d'événement")
agent_type: Optional[AgentType] = Field(default=None, description="Agent concerné")
timestamp: datetime = Field(default_factory=datetime.now, description="Horodatage de l'événement")
data: Dict[str, Any] = Field(default_factory=dict, description="Données associées à l'événement")
class Config:
json_schema_extra = {
"example": {
"event_id": "evt_001",
"event_type": "agent_started",
"agent_type": "researcher",
"timestamp": "2024-01-15T10:00:00Z",
"data": {"query": "impact IA emploi"}
}
}