# Source: Final_Assignment_Template/src/gaio_chat_model.py
# Snapshot by rqueraud, "Before refactoring tools", commit 4d5f444
# DEPRECATED: This file has been replaced by gemini_chat_model.py
# Please use GeminiChatModel instead of GaioChatModel for LLM integration
import os
import json
import re
from typing import Any, Dict, Iterator, List, Optional
from pydantic import Field, SecretStr
from langchain_core.callbacks.manager import CallbackManagerForLLMRun
from langchain_core.language_models.chat_models import BaseChatModel
from langchain_core.messages import AIMessage, AIMessageChunk, BaseMessage, HumanMessage, SystemMessage, ToolMessage
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
from langchain_core.messages.tool import ToolCall
try:
# Try relative import first (when used as package)
from .gaio import Gaio
except ImportError:
# Fall back to absolute import (when run directly)
from gaio import Gaio
class GaioChatModel(BaseChatModel):
    """Custom LangChain chat model wrapper for Gaio API.

    This model integrates with the Gaio API service to provide chat completion
    capabilities within the LangChain framework. Tool use is emulated via
    prompting: bound tools are described in the prompt text and tool calls are
    parsed back out of the model's JSON-formatted reply.

    Example:
        ```python
        model = GaioChatModel(
            api_key="your-api-key",
            api_url="https://your-gaio-endpoint.com/chat/completions"
        )
        response = model.invoke([HumanMessage(content="Hello!")])
        ```
    """

    api_key: SecretStr = Field(description="API key for Gaio service")
    api_url: str = Field(description="API endpoint URL for Gaio service")
    model_name: str = Field(default="azure/gpt-4o", description="Name of the model to use")
    temperature: float = Field(default=0.05, ge=0.0, le=2.0, description="Sampling temperature")
    max_tokens: int = Field(default=1000, gt=0, description="Maximum number of tokens to generate")
    # NOTE(review): model_name/temperature/max_tokens are never forwarded to the
    # Gaio client below (InvokeGaio only receives the prompt) -- confirm whether
    # the Gaio API should receive them.
    gaio_client: Optional[Gaio] = Field(default=None, exclude=True)

    class Config:
        """Pydantic model configuration."""
        arbitrary_types_allowed = True

    def __init__(self, api_key: str, api_url: str, **kwargs):
        """Initialize the model and its underlying Gaio client.

        Args:
            api_key: Plain-text API key; wrapped in a ``SecretStr`` before storage.
            api_url: Gaio chat-completions endpoint URL.
            **kwargs: Extra field values (``model_name``, ``temperature``,
                ``max_tokens``, ...) forwarded to the pydantic/BaseChatModel
                constructor.
        """
        # Set the fields before calling super().__init__ so validation sees them.
        kwargs['api_key'] = SecretStr(api_key)
        kwargs['api_url'] = api_url
        super().__init__(**kwargs)
        # Initialize the Gaio client after parent initialization.
        self.gaio_client = Gaio(api_key, api_url)

    @property
    def _llm_type(self) -> str:
        """Return identifier of the LLM."""
        return "gaio"

    @property
    def _identifying_params(self) -> Dict[str, Any]:
        """Return a dictionary of identifying parameters.

        This information is used by the LangChain callback system for tracing.
        Note: API key is excluded for security reasons.
        """
        return {
            "model_name": self.model_name,
            "api_url": self.api_url,
            "temperature": self.temperature,
            "max_tokens": self.max_tokens,
        }

    def _format_messages_for_gaio(self, messages: List[BaseMessage]) -> str:
        """Convert LangChain messages to a single prompt string for gaio.

        Each message is rendered as a ``role: content`` line; a tool result is
        followed by an instruction to answer directly. If tools were bound via
        :meth:`bind_tools`, a tool-usage prompt is appended at the end.

        Args:
            messages: Conversation history as LangChain message objects.

        Returns:
            The assembled prompt, parts joined by blank lines.

        Raises:
            RuntimeError: If a message type other than Human/AI/System/Tool
                is encountered.
        """
        formatted_parts = []
        for message in messages:
            if isinstance(message, HumanMessage):
                formatted_parts.append(f"user: {message.content}")
            elif isinstance(message, AIMessage):
                formatted_parts.append(f"assistant: {message.content}")
            elif isinstance(message, SystemMessage):
                formatted_parts.append(f"system: {message.content}")
            elif isinstance(message, ToolMessage):
                formatted_parts.append(f"tool_result: {message.content}")
                # Add instruction after tool result
                formatted_parts.append("Now provide your final answer based on the tool result above. Do NOT make another tool call.")
            else:
                raise RuntimeError(f"Unknown message type: {type(message)}")
        # If tools are bound, add tool information to the prompt
        if hasattr(self, '_bound_tools') and self._bound_tools:
            tool_descriptions = []
            for tool in self._bound_tools:
                tool_name = tool.name
                tool_desc = tool.description
                tool_descriptions.append(f"- {tool_name}: {tool_desc}")
            tool_format = '{"tool_call": {"name": "tool_name", "arguments": {"parameter_name": "value"}}}'
            wikipedia_example = '{"tool_call": {"name": "wikipedia_search", "arguments": {"query": "capital of France"}}}'
            youtube_example = '{"tool_call": {"name": "youtube_search", "arguments": {"query": "python tutorial"}}}'
            decode_example = '{"tool_call": {"name": "decode_text", "arguments": {"text": "backwards text here"}}}'
            tools_prompt = f"""
You have access to the following tools:
{chr(10).join(tool_descriptions)}
When you need to use a tool, you MUST respond with exactly this format:
{tool_format}
Examples:
- To search Wikipedia: {wikipedia_example}
- To search YouTube: {youtube_example}
- To decode text: {decode_example}
CRITICAL: Use the correct parameter names:
- wikipedia_search and youtube_search use "query"
- decode_text uses "text"
Always try tools first for factual information before saying you cannot help."""
            formatted_parts.append(tools_prompt)
        return "\n\n".join(formatted_parts)

    def _parse_tool_calls(self, response_content: str) -> tuple[str, List[ToolCall]]:
        """Parse tool calls out of the raw response text.

        Args:
            response_content: Full text returned by the Gaio API.

        Returns:
            A ``(remaining_content, tool_calls)`` pair, where
            ``remaining_content`` is the response with every recognized
            tool-call JSON snippet removed and stripped.
        """
        tool_calls: List[ToolCall] = []
        remaining_content = response_content
        # Look for the JSON tool-call pattern. Note: the argument object is
        # matched with [^}]*, so nested objects in "arguments" are not supported.
        tool_call_pattern = r'\{"tool_call":\s*\{"name":\s*"([^"]+)",\s*"arguments":\s*(\{[^}]*\})\}\}'
        for match in re.finditer(tool_call_pattern, response_content):
            tool_name = match.group(1)
            try:
                arguments = json.loads(match.group(2))
            except json.JSONDecodeError:
                # Malformed arguments: leave the snippet in the text, keep scanning.
                continue
            tool_calls.append(ToolCall(
                name=tool_name,
                args=arguments,
                id=f"call_{len(tool_calls)}",
            ))
            # Remove the tool call from the visible content.
            remaining_content = remaining_content.replace(match.group(0), "").strip()
        return remaining_content, tool_calls

    def _generate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> ChatResult:
        """Generate a response from the model.

        Args:
            messages: Conversation history to send.
            stop: Stop sequences (accepted for interface compatibility; not
                forwarded to Gaio).
            run_manager: Optional callback manager for tracing.

        Returns:
            A ChatResult with a single generation; the AIMessage carries
            parsed tool calls (if any) and estimated token usage.

        Raises:
            RuntimeError: If the underlying Gaio API call fails.
        """
        # Convert messages to prompt format
        prompt = self._format_messages_for_gaio(messages)
        # Call gaio API
        try:
            response_content = self.gaio_client.InvokeGaio(prompt)
            # Parse any tool calls from the response
            content, tool_calls = self._parse_tool_calls(response_content)
            # Estimate token usage (simple approximation)
            input_tokens = self._estimate_tokens(prompt)
            output_tokens = self._estimate_tokens(content)
            usage_metadata = {
                "input_tokens": input_tokens,
                "output_tokens": output_tokens,
                "total_tokens": input_tokens + output_tokens,
            }
            # Create AI message with tool calls if any
            if tool_calls:
                ai_message = AIMessage(
                    content=content,
                    tool_calls=tool_calls,
                    usage_metadata=usage_metadata,
                    response_metadata={"model": self.model_name},
                )
            else:
                ai_message = AIMessage(
                    content=content,
                    usage_metadata=usage_metadata,
                    response_metadata={"model": self.model_name},
                )
            # Create chat generation
            generation = ChatGeneration(
                message=ai_message,
                generation_info={"model": self.model_name},
            )
            return ChatResult(generations=[generation])
        except Exception as e:
            # Chain the original exception so the root cause stays visible.
            raise RuntimeError(f"Error calling Gaio API: {e}") from e

    def _estimate_tokens(self, text: str) -> int:
        """Simple token estimation (roughly 4 characters per token for English)."""
        return max(1, len(text) // 4)

    async def _agenerate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> ChatResult:
        """Async generate - for now, just call the sync version.

        NOTE: this blocks the event loop for the duration of the HTTP call.
        In production, implement true async (e.g. aiohttp or
        ``loop.run_in_executor``).
        """
        return self._generate(messages, stop, run_manager, **kwargs)

    def _stream(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> Iterator[ChatGenerationChunk]:
        """Stream the response. Since Gaio doesn't support streaming, simulate it.

        The full response is generated first, then yielded character by
        character; the final chunk carries the usage/response metadata and any
        tool calls.
        """
        # Get the full response first
        result = self._generate(messages, stop, run_manager, **kwargs)
        message = result.generations[0].message
        content = message.content

        def _final_chunk(text: str) -> ChatGenerationChunk:
            # Last chunk gets full metadata and the parsed tool calls.
            return ChatGenerationChunk(
                message=AIMessageChunk(
                    content=text,
                    usage_metadata=message.usage_metadata,
                    response_metadata=message.response_metadata,
                    tool_calls=getattr(message, 'tool_calls', None),
                )
            )

        # Bug fix: the original yielded nothing for empty content, silently
        # dropping usage metadata and tool calls. Emit one terminal chunk.
        if not content:
            chunk = _final_chunk("")
            if run_manager:
                run_manager.on_llm_new_token("", chunk=chunk)
            yield chunk
            return
        last_index = len(content) - 1
        for i, char in enumerate(content):
            if i == last_index:
                chunk = _final_chunk(char)
            else:
                chunk = ChatGenerationChunk(message=AIMessageChunk(content=char))
            if run_manager:
                run_manager.on_llm_new_token(char, chunk=chunk)
            yield chunk

    def bind_tools(self, tools: List[Any], **kwargs: Any) -> "GaioChatModel":
        """Bind tools to the model.

        Returns a fresh model instance carrying the tools; the tools are only
        used to build the tool-usage section of the prompt in
        :meth:`_format_messages_for_gaio`.

        Args:
            tools: Objects exposing ``name`` and ``description`` attributes.

        Returns:
            A new GaioChatModel with the same configuration and the tools
            attached.
        """
        # Create a copy of the current model with tools bound
        bound_model = GaioChatModel(
            api_key=self.api_key.get_secret_value(),
            api_url=self.api_url,
            model_name=self.model_name,
            temperature=self.temperature,
            max_tokens=self.max_tokens,
        )
        # Store the tools for use in generation (checked via hasattr upstream).
        bound_model._bound_tools = tools
        return bound_model
def main():
    """Test GaioChatModel with a simple question and verify the answer."""
    print("Testing GaioChatModel with a simple math question...")

    # Credentials are taken from the environment; bail out early if absent.
    token = os.getenv("GAIO_API_TOKEN")
    url = os.getenv("GAIO_URL")
    if not (token and url):
        print("❌ Test failed: Missing environment variables.")
        print("Please set the following environment variables:")
        print("- GAIO_API_TOKEN: Your API token")
        print("- GAIO_URL: The API URL")
        return

    try:
        # Build the model and run one invoke() round-trip.
        model = GaioChatModel(api_key=token, api_url=url)
        question = "How much is 2 + 2 ? Only answer with the response number and nothing else."
        print(f"\nQuestion: {question}")
        print("Using LangChain message format...")
        reply = model.invoke([HumanMessage(content=question)])
        text = reply.content
        print(f"Answer: '{text}'")
        # The model must answer exactly "4" (ignoring surrounding whitespace).
        trimmed = text.strip()
        if trimmed == "4":
            print("✅ Test passed! GaioChatModel correctly answered '4'.")
        else:
            print(f"❌ Test failed. Expected '4', but got '{trimmed}'.")
    except Exception as e:
        print(f"❌ Test failed with error: {e}")


if __name__ == "__main__":
    main()