# CyberLegalAI/endpoint/utils/llm_wrapper.py
# Author: Charles Grandjean (commit d81d6f6)
# Original commit note (French, sarcastic): "oui c'etait vraiment necessaire
# d'avoir un client openai incompatible avec les autres"
# -- i.e. "yes, it was really necessary to have an OpenAI client incompatible with the others"
#!/usr/bin/env python3
"""
LLM Response Normalization Wrapper
Normalizes Gemini's list content format to string for consistency with OpenAI.
"""
from typing import Any
from langchain_core.messages import AIMessage
class NormalizedLLM:
    """
    Simple wrapper that normalizes LLM responses across providers.

    Gemini (and some other providers) can return ``message.content`` as a
    *list* of parts — plain strings and/or ``{'text': ...}`` dicts — while
    OpenAI returns a plain string. This wrapper flattens any list content
    so ``response.content`` is always a ``str``.

    Usage:
        gemini_llm = ChatGoogleGenerativeAI(model="gemini-2.0-flash-exp")
        normalized = NormalizedLLM(gemini_llm)
        response = normalized.invoke("Hello")  # response.content is always a string
    """

    def __init__(self, llm: Any):
        # The underlying provider-specific chat model being wrapped.
        self.llm = llm

    def _normalize_content(self, content: Any) -> str:
        """Flatten provider content (str, list of parts, or anything else) to a str.

        List parts may be plain strings or dicts carrying a 'text' key;
        anything else is stringified as a best-effort fallback.
        """
        if isinstance(content, str):
            return content
        if isinstance(content, list):
            parts = []
            for item in content:
                if isinstance(item, str):
                    # BUGFIX: Gemini interleaves bare strings with dict parts;
                    # the previous unconditional item.get(...) raised
                    # AttributeError on these.
                    parts.append(item)
                elif isinstance(item, dict):
                    parts.append(item.get('text', str(item)))
                else:
                    parts.append(str(item))
            return "".join(parts)
        return str(content)

    def invoke(self, input: Any, config: Any = None, **kwargs) -> "AIMessage":
        """Invoke the wrapped LLM and return its response with string content.

        Drops 'extra_body' before delegating — presumably an OpenAI-specific
        kwarg that other providers reject (TODO: confirm against callers).
        """
        filtered_kwargs = {k: v for k, v in kwargs.items() if k != 'extra_body'}
        response = self.llm.invoke(input, config=config, **filtered_kwargs)
        response.content = self._normalize_content(response.content)
        return response

    async def ainvoke(self, input: Any, config: Any = None, **kwargs) -> "AIMessage":
        """Async variant of :meth:`invoke`; same filtering and normalization."""
        filtered_kwargs = {k: v for k, v in kwargs.items() if k != 'extra_body'}
        response = await self.llm.ainvoke(input, config=config, **filtered_kwargs)
        response.content = self._normalize_content(response.content)
        return response

    def bind_tools(self, tools, **kwargs):
        """Bind tools on the underlying LLM and re-wrap so the tool-bound
        model's responses are normalized too."""
        return NormalizedLLM(self.llm.bind_tools(tools, **kwargs))

    def __getattr__(self, name: str) -> Any:
        # Delegate everything else to the wrapped model. NOTE(review):
        # methods other than bind_tools that return a *new* model instance
        # (e.g. with_structured_output) come back unwrapped.
        return getattr(self.llm, name)