# NOTE(review): removed stray paste artifacts ("Spaces:", "Runtime error") that
# preceded the module and were not part of the source.
| # DEPRECATED: This file has been replaced by gemini_chat_model.py | |
| # Please use GeminiChatModel instead of GaioChatModel for LLM integration | |
| import os | |
| import json | |
| import re | |
| from typing import Any, Dict, Iterator, List, Optional | |
| from pydantic import Field, SecretStr | |
| from langchain_core.callbacks.manager import CallbackManagerForLLMRun | |
| from langchain_core.language_models.chat_models import BaseChatModel | |
| from langchain_core.messages import AIMessage, AIMessageChunk, BaseMessage, HumanMessage, SystemMessage, ToolMessage | |
| from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult | |
| from langchain_core.messages.tool import ToolCall | |
| try: | |
| # Try relative import first (when used as package) | |
| from .gaio import Gaio | |
| except ImportError: | |
| # Fall back to absolute import (when run directly) | |
| from gaio import Gaio | |
class GaioChatModel(BaseChatModel):
    """Custom LangChain chat model wrapper for the Gaio API.

    This model integrates with the Gaio API service to provide chat completion
    capabilities within the LangChain framework. Tool calling is emulated via
    prompt engineering: bound tools are described in the prompt, and tool calls
    are parsed back out of the model's JSON-formatted reply.

    Example:
        ```python
        model = GaioChatModel(
            api_key="your-api-key",
            api_url="https://your-gaio-endpoint.com/chat/completions"
        )
        response = model.invoke([HumanMessage(content="Hello!")])
        ```
    """

    api_key: SecretStr = Field(description="API key for Gaio service")
    api_url: str = Field(description="API endpoint URL for Gaio service")
    model_name: str = Field(default="azure/gpt-4o", description="Name of the model to use")
    temperature: float = Field(default=0.05, ge=0.0, le=2.0, description="Sampling temperature")
    max_tokens: int = Field(default=1000, gt=0, description="Maximum number of tokens to generate")
    # Underlying HTTP client; excluded from serialization (holds the raw key).
    gaio_client: Optional[Gaio] = Field(default=None, exclude=True)

    class Config:
        """Pydantic model configuration."""
        # Gaio is not a pydantic type, so it must be explicitly allowed.
        arbitrary_types_allowed = True

    def __init__(self, api_key: str, api_url: str, **kwargs):
        """Initialize the model and its underlying Gaio client.

        Args:
            api_key: Plain-text API key; wrapped in ``SecretStr`` for storage.
            api_url: Endpoint URL of the Gaio service.
            **kwargs: Any other declared field (model_name, temperature, ...).
        """
        # Set the credential fields before pydantic validation runs in
        # super().__init__.
        kwargs['api_key'] = SecretStr(api_key)
        kwargs['api_url'] = api_url
        super().__init__(**kwargs)
        # Initialize the Gaio client after parent initialization.
        self.gaio_client = Gaio(api_key, api_url)

    @property
    def _llm_type(self) -> str:
        """Return identifier of the LLM.

        FIX: declared as a property. ``BaseChatModel`` reads ``self._llm_type``
        as an attribute (e.g. when building cache/trace keys), so a plain
        method would yield a bound-method object instead of the string "gaio".
        """
        return "gaio"

    @property
    def _identifying_params(self) -> Dict[str, Any]:
        """Return a dictionary of identifying parameters.

        This information is used by the LangChain callback system for tracing.
        Declared as a property to match ``BaseChatModel``'s contract.

        Note: API key is excluded for security reasons.
        """
        return {
            "model_name": self.model_name,
            "api_url": self.api_url,
            "temperature": self.temperature,
            "max_tokens": self.max_tokens,
        }

    def _format_messages_for_gaio(self, messages: List[BaseMessage]) -> str:
        """Convert LangChain messages to a single prompt string for Gaio.

        Each message is rendered as ``role: content``. When tools have been
        bound via :meth:`bind_tools`, a tool-usage instruction block is
        appended so the model can emit parseable JSON tool calls.

        Raises:
            RuntimeError: If a message type other than Human/AI/System/Tool
                is encountered.
        """
        formatted_parts = []
        for message in messages:
            if isinstance(message, HumanMessage):
                formatted_parts.append(f"user: {message.content}")
            elif isinstance(message, AIMessage):
                formatted_parts.append(f"assistant: {message.content}")
            elif isinstance(message, SystemMessage):
                formatted_parts.append(f"system: {message.content}")
            elif isinstance(message, ToolMessage):
                formatted_parts.append(f"tool_result: {message.content}")
                # Add instruction after tool result to stop tool-call loops.
                formatted_parts.append("Now provide your final answer based on the tool result above. Do NOT make another tool call.")
            else:
                raise RuntimeError(f"Unknown message type: {type(message)}")

        # If tools are bound, add tool information to the prompt.
        if hasattr(self, '_bound_tools') and self._bound_tools:
            tool_descriptions = []
            for tool in self._bound_tools:
                tool_name = tool.name
                tool_desc = tool.description
                tool_descriptions.append(f"- {tool_name}: {tool_desc}")
            tool_format = '{"tool_call": {"name": "tool_name", "arguments": {"parameter_name": "value"}}}'
            wikipedia_example = '{"tool_call": {"name": "wikipedia_search", "arguments": {"query": "capital of France"}}}'
            youtube_example = '{"tool_call": {"name": "youtube_search", "arguments": {"query": "python tutorial"}}}'
            decode_example = '{"tool_call": {"name": "decode_text", "arguments": {"text": "backwards text here"}}}'
            tools_prompt = f"""
You have access to the following tools:
{chr(10).join(tool_descriptions)}
When you need to use a tool, you MUST respond with exactly this format:
{tool_format}
Examples:
- To search Wikipedia: {wikipedia_example}
- To search YouTube: {youtube_example}
- To decode text: {decode_example}
CRITICAL: Use the correct parameter names:
- wikipedia_search and youtube_search use "query"
- decode_text uses "text"
Always try tools first for factual information before saying you cannot help."""
            formatted_parts.append(tools_prompt)
        return "\n\n".join(formatted_parts)

    def _parse_tool_calls(self, response_content: str) -> tuple[str, List[ToolCall]]:
        """Parse JSON tool calls out of the raw response content.

        Returns:
            A ``(remaining_content, tool_calls)`` tuple where every matched
            tool-call JSON blob has been stripped from the content.

        Note: the arguments regex (``\\{[^}]*\\}``) intentionally matches only
        a flat JSON object — nested-object arguments are not supported.
        """
        tool_calls = []
        remaining_content = response_content
        # Look for the JSON tool call pattern emitted per the bound-tools prompt.
        tool_call_pattern = r'\{"tool_call":\s*\{"name":\s*"([^"]+)",\s*"arguments":\s*(\{[^}]*\})\}\}'
        for match in re.finditer(tool_call_pattern, response_content):
            tool_name = match.group(1)
            try:
                arguments = json.loads(match.group(2))
            except json.JSONDecodeError:
                # Malformed arguments: leave the text in place and keep going.
                continue
            tool_calls.append(ToolCall(
                name=tool_name,
                args=arguments,
                id=f"call_{len(tool_calls)}"
            ))
            # Remove the tool call text from the visible content.
            remaining_content = remaining_content.replace(match.group(0), "").strip()
        return remaining_content, tool_calls

    def _generate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> ChatResult:
        """Generate a response from the model.

        Formats the conversation as a single prompt, calls the Gaio API,
        extracts any tool calls, and wraps the result in a ``ChatResult``.

        Raises:
            RuntimeError: If the underlying API call fails (original
                exception attached as ``__cause__``).
        """
        # Convert messages to prompt format.
        prompt = self._format_messages_for_gaio(messages)
        try:
            response_content = self.gaio_client.InvokeGaio(prompt)
        except Exception as e:
            # FIX: chain the cause so the original traceback is preserved.
            raise RuntimeError(f"Error calling Gaio API: {e}") from e

        # Parse any tool calls out of the response text.
        content, tool_calls = self._parse_tool_calls(response_content)

        # Estimate token usage (simple approximation — Gaio reports none).
        input_tokens = self._estimate_tokens(prompt)
        output_tokens = self._estimate_tokens(content)
        usage_metadata = {
            "input_tokens": input_tokens,
            "output_tokens": output_tokens,
            "total_tokens": input_tokens + output_tokens
        }

        message_kwargs: Dict[str, Any] = {
            "content": content,
            "usage_metadata": usage_metadata,
            "response_metadata": {"model": self.model_name},
        }
        # Only attach tool_calls when present, matching the original behavior.
        if tool_calls:
            message_kwargs["tool_calls"] = tool_calls
        ai_message = AIMessage(**message_kwargs)

        generation = ChatGeneration(
            message=ai_message,
            generation_info={"model": self.model_name}
        )
        return ChatResult(generations=[generation])

    def _estimate_tokens(self, text: str) -> int:
        """Simple token estimation (roughly 4 characters per token for English)."""
        return max(1, len(text) // 4)

    async def _agenerate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> ChatResult:
        """Async generate — delegates to the sync implementation.

        NOTE(review): this blocks the event loop; a production version should
        offload to an executor or use an async HTTP client.
        """
        return self._generate(messages, stop, run_manager, **kwargs)

    def _stream(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> Iterator[ChatGenerationChunk]:
        """Stream the response. Gaio has no streaming API, so simulate it by
        yielding the full response one character at a time.

        FIX: the final chunk no longer passes ``tool_calls=None`` (the field
        expects a list, not None), the redundant re-check of the last-index
        condition is gone, and an empty response now still yields one
        metadata-bearing chunk instead of yielding nothing.
        """
        # Get the full response first.
        result = self._generate(messages, stop, run_manager, **kwargs)
        message = result.generations[0].message
        content = message.content or ""

        # All but the last character are plain content chunks.
        for char in content[:-1]:
            chunk = ChatGenerationChunk(message=AIMessageChunk(content=char))
            if run_manager:
                run_manager.on_llm_new_token(char, chunk=chunk)
            yield chunk

        # The final chunk carries usage/response metadata (and any tool calls),
        # and is emitted even when the content is empty.
        last_char = content[-1] if content else ""
        final_kwargs: Dict[str, Any] = {
            "content": last_char,
            "usage_metadata": message.usage_metadata,
            "response_metadata": message.response_metadata,
        }
        tool_calls = getattr(message, 'tool_calls', None)
        if tool_calls:
            final_kwargs["tool_calls"] = tool_calls
        final_chunk = ChatGenerationChunk(message=AIMessageChunk(**final_kwargs))
        if run_manager:
            run_manager.on_llm_new_token(last_char, chunk=final_chunk)
        yield final_chunk

    def bind_tools(self, tools: List[Any], **kwargs: Any) -> "GaioChatModel":
        """Bind tools to the model.

        Returns a new model instance carrying the tools; the tools are later
        injected into the prompt by :meth:`_format_messages_for_gaio`.

        NOTE(review): ``**kwargs`` (e.g. ``tool_choice``) are currently
        ignored, as in the original implementation.
        """
        bound_model = GaioChatModel(
            api_key=self.api_key.get_secret_value(),
            api_url=self.api_url,
            model_name=self.model_name,
            temperature=self.temperature,
            max_tokens=self.max_tokens
        )
        # Store the tools for use during prompt construction.
        bound_model._bound_tools = tools
        return bound_model
def main():
    """Smoke-test GaioChatModel: ask a trivial math question, verify the reply.

    Requires the GAIO_API_TOKEN and GAIO_URL environment variables; prints a
    pass/fail line instead of raising.
    """
    print("Testing GaioChatModel with a simple math question...")

    # Pull credentials from the environment; bail out early if either is missing.
    api_key = os.getenv("GAIO_API_TOKEN")
    api_url = os.getenv("GAIO_URL")
    if not (api_key and api_url):
        for line in (
            "❌ Test failed: Missing environment variables.",
            "Please set the following environment variables:",
            "- GAIO_API_TOKEN: Your API token",
            "- GAIO_URL: The API URL",
        ):
            print(line)
        return

    try:
        chat_model = GaioChatModel(api_key=api_key, api_url=api_url)

        question = "How much is 2 + 2 ? Only answer with the response number and nothing else."
        conversation = [HumanMessage(content=question)]
        print(f"\nQuestion: {question}")
        print("Using LangChain message format...")

        # Invoke via the standard LangChain entry point.
        reply = chat_model.invoke(conversation)
        answer = reply.content
        print(f"Answer: '{answer}'")

        # The model should answer with exactly "4" (after trimming whitespace).
        trimmed = answer.strip()
        if trimmed == "4":
            print("✅ Test passed! GaioChatModel correctly answered '4'.")
        else:
            print(f"❌ Test failed. Expected '4', but got '{trimmed}'.")
    except Exception as e:
        print(f"❌ Test failed with error: {e}")
# Run the manual smoke test only when this file is executed directly.
if __name__ == "__main__":
    main()