| """ |
| Base agent class for all CV evaluation agents. |
| Supports multiple LLM providers: OpenAI (ChatGPT), Google (Gemini) via LangChain. |
| """ |
|
|
| import json |
| import logging |
| import os |
| import re |
| from typing import Any, Literal, TypeVar |
|
|
| from langchain_core.language_models.chat_models import BaseChatModel |
| from langchain_core.messages import HumanMessage, SystemMessage |
| from pydantic import BaseModel, ValidationError |
| from tenacity import ( |
| retry, |
| retry_if_exception_type, |
| stop_after_attempt, |
| wait_exponential, |
| ) |
|
|
| logger = logging.getLogger(__name__) |
|
|
| T = TypeVar("T", bound=BaseModel) |
|
|
| |
| ProviderType = Literal["gemini", "openai", "ollama"] |
|
|
|
|
| def create_llm( |
| provider: ProviderType, |
| model_name: str | None, |
| temperature: float, |
| api_key: str | None, |
| ) -> BaseChatModel: |
| """Factory function to instantiate the correct LangChain LLM based on provider.""" |
|
|
| if provider == "gemini": |
| from langchain_google_genai import ChatGoogleGenerativeAI |
|
|
| resolved_key = api_key or os.getenv("GOOGLE_API_KEY", "") |
| resolved_model = model_name or os.getenv("GEMINI_MODEL", "gemini-1.5-flash") |
|
|
| if not resolved_key: |
| raise ValueError( |
| "GOOGLE_API_KEY not found. Set it in .env or pass it directly." |
| ) |
|
|
| return ChatGoogleGenerativeAI( |
| model=resolved_model, |
| google_api_key=resolved_key, |
| temperature=temperature, |
| convert_system_message_to_human=True, |
| ) |
|
|
| elif provider == "openai": |
| from langchain_openai import ChatOpenAI |
|
|
| resolved_key = api_key or os.getenv("OPENAI_API_KEY", "") |
| resolved_model = model_name or os.getenv("OPENAI_MODEL", "gpt-4o-mini") |
|
|
| if not resolved_key: |
| raise ValueError( |
| "OPENAI_API_KEY not found. Set it in .env or pass it directly." |
| ) |
|
|
| return ChatOpenAI( |
| model=resolved_model, |
| openai_api_key=resolved_key, |
| temperature=temperature, |
| request_timeout=60, |
| max_retries=2, |
| ) |
|
|
| elif provider == "ollama": |
| from langchain_openai import ChatOpenAI |
|
|
| |
| resolved_key = api_key or os.getenv("OLLAMA_API_KEY") |
| resolved_model = model_name or os.getenv("OLLAMA_MODEL", "glm-5.1:cloud") |
|
|
| if not resolved_key: |
| raise ValueError( |
| "OLLAMA_API_KEY not found. Set it in .env or pass it directly." |
| ) |
|
|
| return ChatOpenAI( |
| model=resolved_model, |
| base_url="https://ollama.com/v1", |
| openai_api_key=resolved_key, |
| temperature=temperature, |
| request_timeout=120, |
| max_retries=2, |
| ) |
|
|
| else: |
| raise ValueError( |
| f"Unsupported provider: '{provider}'. Choose 'gemini', 'openai', or 'ollama'." |
| ) |
|
|
|
|
| class BaseAgent: |
| """ |
| Base class for all CV evaluation agents. |
| Supports OpenAI (ChatGPT) and Google (Gemini) via LangChain. |
| """ |
|
|
| def __init__( |
| self, |
| name: str, |
| role: str, |
| provider: ProviderType = "gemini", |
| model_name: str | None = None, |
| temperature: float = 0, |
| api_key: str | None = None, |
| ): |
| self.name = name |
| self.role = role |
| self.provider = provider |
| self.model_name = model_name |
| self.temperature = temperature |
|
|
| self.llm: BaseChatModel = create_llm( |
| provider=provider, |
| model_name=model_name, |
| temperature=temperature, |
| api_key=api_key, |
| ) |
|
|
| logger.info( |
| f"[{self.name}] Initialized with provider='{provider}', model='{self.model_name or 'default'}'" |
| ) |
|
|
| def _build_messages(self, prompt: str) -> list: |
| """ |
| Build the message list for the LLM. |
| Gemini uses convert_system_message_to_human, so SystemMessage is safe for OpenAI |
| and handled automatically for Gemini. |
| """ |
| messages = [] |
| if self.role: |
| messages.append(SystemMessage(content=self.role)) |
| messages.append(HumanMessage(content=prompt)) |
| return messages |
|
|
| def _extract_json_from_response(self, text: str) -> str: |
| """Extract JSON from LLM response, handling markdown code blocks.""" |
| patterns = [ |
| r"```json\s*([\s\S]*?)```", |
| r"```\s*([\s\S]*?)```", |
| r"(\{[\s\S]*\})", |
| ] |
| for pattern in patterns: |
| match = re.search(pattern, text) |
| if match: |
| candidate = match.group(1).strip() |
| try: |
| json.loads(candidate) |
| return candidate |
| except json.JSONDecodeError: |
| continue |
|
|
| return text.strip() |
|
|
| @retry( |
| stop=stop_after_attempt(3), |
| wait=wait_exponential(multiplier=1, min=2, max=10), |
| retry=retry_if_exception_type((json.JSONDecodeError, ValidationError)), |
| reraise=True, |
| ) |
| def _call_llm_with_retry(self, prompt: str, output_model: type[T]) -> T: |
| """Call LLM with retry logic for JSON parsing failures.""" |
| logger.info(f"[{self.name}] Calling {self.provider} LLM...") |
|
|
| messages = self._build_messages(prompt) |
| response = self.llm.invoke(messages) |
| raw_text = response.content |
|
|
| logger.debug(f"[{self.name}] Raw response:\n{raw_text}") |
|
|
| json_str = self._extract_json_from_response(raw_text) |
|
|
| |
| try: |
| data = json.loads(json_str) |
| except json.JSONDecodeError as e: |
| logger.warning(f"[{self.name}] JSON parse error: {e}. Asking LLM to fix...") |
| fix_prompt = ( |
| f"Le texte suivant devait être un JSON valide mais contient des erreurs. " |
| f"Corrige-le et renvoie UNIQUEMENT le JSON valide :\n\n{json_str}" |
| ) |
| fix_response = self.llm.invoke([HumanMessage(content=fix_prompt)]) |
| json_str = self._extract_json_from_response(fix_response.content) |
| data = json.loads(json_str) |
|
|
| |
| try: |
| result = output_model.model_validate(data) |
| except ValidationError as e: |
| |
| logger.warning( |
| f"[{self.name}] Pydantic ValidationError:\n{e}\n" |
| f"Data received:\n{json.dumps(data, indent=2, ensure_ascii=False)}" |
| ) |
| |
| schema = json.dumps( |
| output_model.model_json_schema(), indent=2, ensure_ascii=False |
| ) |
| fix_prompt = ( |
| f"Le JSON suivant ne respecte pas le schéma attendu.\n\n" |
| f"SCHÉMA:\n{schema}\n\n" |
| f"JSON REÇU:\n{json.dumps(data, indent=2, ensure_ascii=False)}\n\n" |
| f"ERREURS DE VALIDATION:\n{str(e)}\n\n" |
| f"Renvoie UNIQUEMENT un JSON corrigé qui respecte exactement le schéma." |
| ) |
| fix_response = self.llm.invoke([HumanMessage(content=fix_prompt)]) |
| json_str = self._extract_json_from_response(fix_response.content) |
| data = json.loads(json_str) |
| result = output_model.model_validate(data) |
|
|
| logger.info(f"[{self.name}] Successfully parsed and validated output.") |
| return result |
|
|
| def run(self, **kwargs) -> Any: |
| """Override in subclasses.""" |
| raise NotImplementedError("Subclasses must implement run()") |
|
|