#!/usr/bin/env python3
"""
LLM Response Normalization Wrapper
Normalizes Gemini's list content format to string for consistency with OpenAI.
"""
from typing import Any
from langchain_core.messages import AIMessage
class NormalizedLLM:
    """
    Simple wrapper that normalizes LLM responses across providers.

    Gemini-style models may return ``message.content`` as a list of content
    parts (dicts with a ``'text'`` key, plain strings, or other objects),
    while OpenAI returns a plain string. This wrapper flattens either shape
    so ``response.content`` is always a string.

    Usage:
        gemini_llm = ChatGoogleGenerativeAI(model="gemini-2.0-flash-exp")
        normalized = NormalizedLLM(gemini_llm)
        response = normalized.invoke("Hello")  # response.content is always a string
    """

    def __init__(self, llm: Any):
        # The wrapped provider LLM; all unknown attributes delegate to it.
        self.llm = llm

    def _normalize_content(self, content: Any) -> str:
        """Flatten list-of-parts content (Gemini style) into a single string.

        Parts may be dicts ({'type': 'text', 'text': ...}), plain strings,
        or arbitrary objects. The previous implementation called
        ``item.get(...)`` unconditionally and raised AttributeError on
        non-dict parts such as plain strings.
        """
        if isinstance(content, str):
            return content
        if isinstance(content, list):
            parts = []
            for item in content:
                if isinstance(item, str):
                    parts.append(item)
                elif isinstance(item, dict):
                    parts.append(item.get('text', str(item)))
                else:
                    parts.append(str(item))
            return "".join(parts)
        return str(content)

    def invoke(self, input: Any, config: Any = None, **kwargs) -> "AIMessage":
        """Invoke the wrapped LLM and normalize the response content to str.

        'extra_body' is stripped because the wrapped provider may not
        support it (it is an OpenAI-specific passthrough parameter).
        """
        filtered_kwargs = {k: v for k, v in kwargs.items() if k != 'extra_body'}
        response = self.llm.invoke(input, config=config, **filtered_kwargs)
        response.content = self._normalize_content(response.content)
        return response

    async def ainvoke(self, input: Any, config: Any = None, **kwargs) -> "AIMessage":
        """Async variant of :meth:`invoke`; same filtering and normalization."""
        filtered_kwargs = {k: v for k, v in kwargs.items() if k != 'extra_body'}
        response = await self.llm.ainvoke(input, config=config, **filtered_kwargs)
        response.content = self._normalize_content(response.content)
        return response

    def bind_tools(self, tools, **kwargs):
        """Bind tools on the wrapped LLM and re-wrap so responses stay normalized."""
        return NormalizedLLM(self.llm.bind_tools(tools, **kwargs))

    def __getattr__(self, name: str) -> Any:
        # Delegate everything else (model params, metadata, etc.) to the
        # wrapped LLM. Guard against infinite recursion when 'llm' itself is
        # not yet in __dict__ (e.g. during unpickling/copy, before __init__).
        if name == 'llm':
            raise AttributeError(name)
        return getattr(self.llm, name)