File size: 8,316 Bytes
8a848a5 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 |
"""
Modèles d'état pour l'orchestration LangGraph.
Définit l'état global du système et les états des agents.
"""
from typing import List, Optional, Dict, Any, Union
from datetime import datetime
from pydantic import BaseModel, Field
from enum import Enum
from .research_models import ResearchQuery, ResearchOutput
from .document_models import SummarizationOutput
from .report_models import ReportOutput
class AgentType(str, Enum):
"""Types d'agents dans le système."""
RESEARCHER = "researcher"
CONTENT_EXTRACTOR = "content_extractor"
READER = "reader"
WRITER = "writer"
class AgentStatus(str, Enum):
"""Statuts possibles d'un agent."""
IDLE = "idle"
WORKING = "working"
COMPLETED = "completed"
ERROR = "error"
TIMEOUT = "timeout"
class ProcessingStep(str, Enum):
"""Étapes du processus de recherche."""
INIT = "init"
RESEARCH = "research"
READING = "reading"
WRITING = "writing"
COMPLETED = "completed"
ERROR = "error"
class AgentState(BaseModel):
"""
État individuel d'un agent.
"""
agent_type: AgentType = Field(..., description="Type de l'agent")
status: AgentStatus = Field(default=AgentStatus.IDLE, description="Statut actuel")
# Informations de timing
start_time: Optional[datetime] = Field(default=None, description="Heure de début d'exécution")
end_time: Optional[datetime] = Field(default=None, description="Heure de fin d'exécution")
duration: Optional[float] = Field(default=None, description="Durée d'exécution en secondes")
# Gestion des erreurs
error_message: Optional[str] = Field(default=None, description="Message d'erreur si applicable")
retry_count: int = Field(default=0, ge=0, description="Nombre de tentatives")
max_retries: int = Field(default=3, ge=0, description="Nombre maximum de tentatives")
# Métadonnées spécifiques à l'agent
metadata: Dict[str, Any] = Field(default_factory=dict, description="Données spécifiques à l'agent")
def start_execution(self):
"""Marque le début de l'exécution."""
self.status = AgentStatus.WORKING
self.start_time = datetime.now()
self.end_time = None
def complete_execution(self):
"""Marque la fin réussie de l'exécution."""
self.status = AgentStatus.COMPLETED
self.end_time = datetime.now()
if self.start_time:
self.duration = (self.end_time - self.start_time).total_seconds()
def mark_error(self, error_message: str):
"""Marque l'agent en erreur."""
self.status = AgentStatus.ERROR
self.error_message = error_message
self.end_time = datetime.now()
if self.start_time:
self.duration = (self.end_time - self.start_time).total_seconds()
class Config:
json_schema_extra = {
"example": {
"agent_type": "researcher",
"status": "completed",
"start_time": "2024-01-15T10:00:00Z",
"end_time": "2024-01-15T10:02:30Z",
"duration": 150.0,
"retry_count": 0,
"metadata": {"search_engine": "tavily"}
}
}
class GraphState(BaseModel):
"""
État global du graph LangGraph.
Contient toutes les données partagées entre les agents.
"""
# Identification de la session
session_id: str = Field(..., description="Identifiant unique de la session")
current_step: ProcessingStep = Field(default=ProcessingStep.INIT, description="Étape actuelle du processus")
# Requête initiale
original_query: Optional[ResearchQuery] = Field(default=None, description="Requête de recherche originale")
# États des agents
agents: Dict[AgentType, AgentState] = Field(
default_factory=lambda: {
AgentType.RESEARCHER: AgentState(agent_type=AgentType.RESEARCHER),
AgentType.READER: AgentState(agent_type=AgentType.READER),
AgentType.WRITER: AgentState(agent_type=AgentType.WRITER)
},
description="État de chaque agent"
)
# Données partagées entre agents
research_output: Optional[ResearchOutput] = Field(default=None, description="Résultats de recherche")
summarization_output: Optional[SummarizationOutput] = Field(default=None, description="Résultats de synthèse")
report_output: Optional[ReportOutput] = Field(default=None, description="Rapport final")
# Métadonnées globales
start_time: datetime = Field(default_factory=datetime.now, description="Heure de début du processus")
end_time: Optional[datetime] = Field(default=None, description="Heure de fin du processus")
total_duration: Optional[float] = Field(default=None, description="Durée totale en secondes")
# Configuration et paramètres
config: Dict[str, Any] = Field(default_factory=dict, description="Configuration du processus")
user_preferences: Dict[str, Any] = Field(default_factory=dict, description="Préférences utilisateur")
# Gestion des erreurs globales
global_errors: List[str] = Field(default_factory=list, description="Erreurs globales du processus")
is_successful: bool = Field(default=False, description="Indique si le processus s'est terminé avec succès")
# Informations de débogage
debug_info: Dict[str, Any] = Field(default_factory=dict, description="Informations de débogage")
def get_current_agent(self) -> Optional[AgentType]:
"""Retourne l'agent actuellement en cours d'exécution."""
for agent_type, agent_state in self.agents.items():
if agent_state.status == AgentStatus.WORKING:
return agent_type
return None
def is_agent_completed(self, agent_type: AgentType) -> bool:
"""Vérifie si un agent a terminé son exécution."""
return self.agents[agent_type].status == AgentStatus.COMPLETED
def all_agents_completed(self) -> bool:
"""Vérifie si tous les agents ont terminé."""
return all(
agent.status == AgentStatus.COMPLETED
for agent in self.agents.values()
)
def has_errors(self) -> bool:
"""Vérifie s'il y a des erreurs dans le processus."""
return (
len(self.global_errors) > 0 or
any(agent.status == AgentStatus.ERROR for agent in self.agents.values())
)
def complete_process(self):
"""Marque le processus comme terminé."""
self.end_time = datetime.now()
self.total_duration = (self.end_time - self.start_time).total_seconds()
self.current_step = ProcessingStep.COMPLETED
self.is_successful = not self.has_errors()
def add_global_error(self, error_message: str):
"""Ajoute une erreur globale."""
self.global_errors.append(error_message)
self.current_step = ProcessingStep.ERROR
class Config:
json_schema_extra = {
"example": {
"session_id": "session_123",
"current_step": "research",
"original_query": {
"topic": "impact de l'IA sur l'emploi"
},
"start_time": "2024-01-15T10:00:00Z",
"is_successful": False,
"global_errors": []
}
}
class WorkflowEvent(BaseModel):
"""
Événement dans le workflow LangGraph.
"""
event_id: str = Field(..., description="Identifiant unique de l'événement")
event_type: str = Field(..., description="Type d'événement")
agent_type: Optional[AgentType] = Field(default=None, description="Agent concerné")
timestamp: datetime = Field(default_factory=datetime.now, description="Horodatage de l'événement")
data: Dict[str, Any] = Field(default_factory=dict, description="Données associées à l'événement")
class Config:
json_schema_extra = {
"example": {
"event_id": "evt_001",
"event_type": "agent_started",
"agent_type": "researcher",
"timestamp": "2024-01-15T10:00:00Z",
"data": {"query": "impact IA emploi"}
}
} |