Spaces:
Sleeping
Sleeping
updated
Browse files
src.py
CHANGED
|
@@ -1,1746 +1,518 @@
|
|
| 1 |
-
|
| 2 |
-
|
| 3 |
-
import asyncio
|
| 4 |
-
import base64
|
| 5 |
-
import json
|
| 6 |
-
import logging
|
| 7 |
-
import mimetypes
|
| 8 |
-
import uuid
|
| 9 |
-
import warnings
|
| 10 |
-
from difflib import get_close_matches
|
| 11 |
-
from operator import itemgetter
|
| 12 |
-
from typing import (
|
| 13 |
-
Any,
|
| 14 |
-
AsyncIterator,
|
| 15 |
-
Callable,
|
| 16 |
-
Dict,
|
| 17 |
-
Iterator,
|
| 18 |
-
List,
|
| 19 |
-
Mapping,
|
| 20 |
-
Optional,
|
| 21 |
-
Sequence,
|
| 22 |
-
Tuple,
|
| 23 |
-
Type,
|
| 24 |
-
Union,
|
| 25 |
-
cast,
|
| 26 |
-
)
|
| 27 |
-
|
| 28 |
-
import filetype # type: ignore[import]
|
| 29 |
-
import google.api_core
|
| 30 |
-
|
| 31 |
-
# TODO: remove ignore once the google package is published with types
|
| 32 |
-
import proto # type: ignore[import]
|
| 33 |
-
from google.ai.generativelanguage_v1beta import (
|
| 34 |
-
GenerativeServiceAsyncClient as v1betaGenerativeServiceAsyncClient,
|
| 35 |
-
)
|
| 36 |
-
from google.ai.generativelanguage_v1beta.types import (
|
| 37 |
-
Blob,
|
| 38 |
-
Candidate,
|
| 39 |
-
CodeExecution,
|
| 40 |
-
Content,
|
| 41 |
-
FileData,
|
| 42 |
-
FunctionCall,
|
| 43 |
-
FunctionDeclaration,
|
| 44 |
-
FunctionResponse,
|
| 45 |
-
GenerateContentRequest,
|
| 46 |
-
GenerateContentResponse,
|
| 47 |
-
GenerationConfig,
|
| 48 |
-
Part,
|
| 49 |
-
SafetySetting,
|
| 50 |
-
ToolConfig,
|
| 51 |
-
VideoMetadata,
|
| 52 |
-
)
|
| 53 |
-
from google.ai.generativelanguage_v1beta.types import Tool as GoogleTool
|
| 54 |
-
from langchain_core.callbacks.manager import (
|
| 55 |
-
AsyncCallbackManagerForLLMRun,
|
| 56 |
-
CallbackManagerForLLMRun,
|
| 57 |
-
)
|
| 58 |
-
from langchain_core.language_models import LanguageModelInput
|
| 59 |
-
from langchain_core.language_models.chat_models import BaseChatModel, LangSmithParams
|
| 60 |
-
from langchain_core.messages import (
|
| 61 |
-
AIMessage,
|
| 62 |
-
AIMessageChunk,
|
| 63 |
-
BaseMessage,
|
| 64 |
-
FunctionMessage,
|
| 65 |
-
HumanMessage,
|
| 66 |
-
SystemMessage,
|
| 67 |
-
ToolMessage,
|
| 68 |
-
is_data_content_block,
|
| 69 |
-
)
|
| 70 |
-
from langchain_core.messages.ai import UsageMetadata
|
| 71 |
-
from langchain_core.messages.tool import invalid_tool_call, tool_call, tool_call_chunk
|
| 72 |
-
from langchain_core.output_parsers.base import OutputParserLike
|
| 73 |
-
from langchain_core.output_parsers.openai_tools import (
|
| 74 |
-
JsonOutputKeyToolsParser,
|
| 75 |
-
PydanticToolsParser,
|
| 76 |
-
parse_tool_calls,
|
| 77 |
-
)
|
| 78 |
-
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
|
| 79 |
-
from langchain_core.runnables import Runnable, RunnableConfig, RunnablePassthrough
|
| 80 |
-
from langchain_core.tools import BaseTool
|
| 81 |
-
from langchain_core.utils import get_pydantic_field_names
|
| 82 |
-
from langchain_core.utils.function_calling import convert_to_openai_tool
|
| 83 |
-
from langchain_core.utils.utils import _build_model_kwargs
|
| 84 |
-
from pydantic import (
|
| 85 |
-
BaseModel,
|
| 86 |
-
ConfigDict,
|
| 87 |
-
Field,
|
| 88 |
-
SecretStr,
|
| 89 |
-
model_validator,
|
| 90 |
-
)
|
| 91 |
-
from tenacity import (
|
| 92 |
-
before_sleep_log,
|
| 93 |
-
retry,
|
| 94 |
-
retry_if_exception_type,
|
| 95 |
-
stop_after_attempt,
|
| 96 |
-
wait_exponential,
|
| 97 |
-
)
|
| 98 |
-
from typing_extensions import Self, is_typeddict
|
| 99 |
-
|
| 100 |
-
from langchain_google_genai._common import (
|
| 101 |
-
GoogleGenerativeAIError,
|
| 102 |
-
SafetySettingDict,
|
| 103 |
-
_BaseGoogleGenerativeAI,
|
| 104 |
-
get_client_info,
|
| 105 |
-
)
|
| 106 |
-
from langchain_google_genai._function_utils import (
|
| 107 |
-
_tool_choice_to_tool_config,
|
| 108 |
-
_ToolChoiceType,
|
| 109 |
-
_ToolConfigDict,
|
| 110 |
-
_ToolDict,
|
| 111 |
-
convert_to_genai_function_declarations,
|
| 112 |
-
is_basemodel_subclass_safe,
|
| 113 |
-
tool_to_dict,
|
| 114 |
-
)
|
| 115 |
-
from langchain_google_genai._image_utils import (
|
| 116 |
-
ImageBytesLoader,
|
| 117 |
-
image_bytes_to_b64_string,
|
| 118 |
-
)
|
| 119 |
-
|
| 120 |
-
from . import _genai_extension as genaix
|
| 121 |
-
|
| 122 |
-
logger = logging.getLogger(__name__)
|
| 123 |
-
|
| 124 |
-
|
| 125 |
-
_FunctionDeclarationType = Union[
|
| 126 |
-
FunctionDeclaration,
|
| 127 |
-
dict[str, Any],
|
| 128 |
-
Callable[..., Any],
|
| 129 |
-
]
|
| 130 |
-
|
| 131 |
-
|
| 132 |
-
class ChatGoogleGenerativeAIError(GoogleGenerativeAIError):
|
| 133 |
-
"""
|
| 134 |
-
Custom exception class for errors associated with the `Google GenAI` API.
|
| 135 |
-
|
| 136 |
-
This exception is raised when there are specific issues related to the
|
| 137 |
-
Google genai API usage in the ChatGoogleGenerativeAI class, such as unsupported
|
| 138 |
-
message types or roles.
|
| 139 |
-
"""
|
| 140 |
-
|
| 141 |
-
|
| 142 |
-
def _create_retry_decorator() -> Callable[[Any], Any]:
|
| 143 |
-
"""
|
| 144 |
-
Creates and returns a preconfigured tenacity retry decorator.
|
| 145 |
-
|
| 146 |
-
The retry decorator is configured to handle specific Google API exceptions
|
| 147 |
-
such as ResourceExhausted and ServiceUnavailable. It uses an exponential
|
| 148 |
-
backoff strategy for retries.
|
| 149 |
-
|
| 150 |
-
Returns:
|
| 151 |
-
Callable[[Any], Any]: A retry decorator configured for handling specific
|
| 152 |
-
Google API exceptions.
|
| 153 |
-
"""
|
| 154 |
-
multiplier = 2
|
| 155 |
-
min_seconds = 1
|
| 156 |
-
max_seconds = 60
|
| 157 |
-
max_retries = 2
|
| 158 |
-
|
| 159 |
-
return retry(
|
| 160 |
-
reraise=True,
|
| 161 |
-
stop=stop_after_attempt(max_retries),
|
| 162 |
-
wait=wait_exponential(multiplier=multiplier, min=min_seconds, max=max_seconds),
|
| 163 |
-
retry=(
|
| 164 |
-
retry_if_exception_type(google.api_core.exceptions.ResourceExhausted)
|
| 165 |
-
| retry_if_exception_type(google.api_core.exceptions.ServiceUnavailable)
|
| 166 |
-
| retry_if_exception_type(google.api_core.exceptions.GoogleAPIError)
|
| 167 |
-
),
|
| 168 |
-
before_sleep=before_sleep_log(logger, logging.WARNING),
|
| 169 |
-
)
|
| 170 |
-
|
| 171 |
-
|
| 172 |
-
def _chat_with_retry(generation_method: Callable, **kwargs: Any) -> Any:
|
| 173 |
-
"""
|
| 174 |
-
Executes a chat generation method with retry logic using tenacity.
|
| 175 |
-
|
| 176 |
-
This function is a wrapper that applies a retry mechanism to a provided
|
| 177 |
-
chat generation function. It is useful for handling intermittent issues
|
| 178 |
-
like network errors or temporary service unavailability.
|
| 179 |
-
|
| 180 |
-
Args:
|
| 181 |
-
generation_method (Callable): The chat generation method to be executed.
|
| 182 |
-
**kwargs (Any): Additional keyword arguments to pass to the generation method.
|
| 183 |
-
|
| 184 |
-
Returns:
|
| 185 |
-
Any: The result from the chat generation method.
|
| 186 |
-
"""
|
| 187 |
-
retry_decorator = _create_retry_decorator()
|
| 188 |
-
|
| 189 |
-
@retry_decorator
|
| 190 |
-
def _chat_with_retry(**kwargs: Any) -> Any:
|
| 191 |
-
try:
|
| 192 |
-
return generation_method(**kwargs)
|
| 193 |
-
# Do not retry for these errors.
|
| 194 |
-
except google.api_core.exceptions.FailedPrecondition as exc:
|
| 195 |
-
if "location is not supported" in exc.message:
|
| 196 |
-
error_msg = (
|
| 197 |
-
"Your location is not supported by google-generativeai "
|
| 198 |
-
"at the moment. Try to use ChatVertexAI LLM from "
|
| 199 |
-
"langchain_google_vertexai."
|
| 200 |
-
)
|
| 201 |
-
raise ValueError(error_msg)
|
| 202 |
-
|
| 203 |
-
except google.api_core.exceptions.InvalidArgument as e:
|
| 204 |
-
raise ChatGoogleGenerativeAIError(
|
| 205 |
-
f"Invalid argument provided to Gemini: {e}"
|
| 206 |
-
) from e
|
| 207 |
-
except Exception as e:
|
| 208 |
-
raise e
|
| 209 |
-
|
| 210 |
-
return _chat_with_retry(**kwargs)
|
| 211 |
-
|
| 212 |
-
|
| 213 |
-
async def _achat_with_retry(generation_method: Callable, **kwargs: Any) -> Any:
|
| 214 |
-
"""
|
| 215 |
-
Executes a chat generation method with retry logic using tenacity.
|
| 216 |
-
|
| 217 |
-
This function is a wrapper that applies a retry mechanism to a provided
|
| 218 |
-
chat generation function. It is useful for handling intermittent issues
|
| 219 |
-
like network errors or temporary service unavailability.
|
| 220 |
-
|
| 221 |
-
Args:
|
| 222 |
-
generation_method (Callable): The chat generation method to be executed.
|
| 223 |
-
**kwargs (Any): Additional keyword arguments to pass to the generation method.
|
| 224 |
-
|
| 225 |
-
Returns:
|
| 226 |
-
Any: The result from the chat generation method.
|
| 227 |
-
"""
|
| 228 |
-
retry_decorator = _create_retry_decorator()
|
| 229 |
-
from google.api_core.exceptions import InvalidArgument # type: ignore
|
| 230 |
-
|
| 231 |
-
@retry_decorator
|
| 232 |
-
async def _achat_with_retry(**kwargs: Any) -> Any:
|
| 233 |
-
try:
|
| 234 |
-
return await generation_method(**kwargs)
|
| 235 |
-
except InvalidArgument as e:
|
| 236 |
-
# Do not retry for these errors.
|
| 237 |
-
raise ChatGoogleGenerativeAIError(
|
| 238 |
-
f"Invalid argument provided to Gemini: {e}"
|
| 239 |
-
) from e
|
| 240 |
-
except Exception as e:
|
| 241 |
-
raise e
|
| 242 |
-
|
| 243 |
-
return await _achat_with_retry(**kwargs)
|
| 244 |
-
|
| 245 |
-
|
| 246 |
-
def _is_lc_content_block(part: dict) -> bool:
|
| 247 |
-
return "type" in part
|
| 248 |
-
|
| 249 |
-
|
| 250 |
-
def _is_openai_image_block(block: dict) -> bool:
|
| 251 |
-
"""Check if the block contains image data in OpenAI Chat Completions format."""
|
| 252 |
-
if block.get("type") == "image_url":
|
| 253 |
-
if (
|
| 254 |
-
(set(block.keys()) <= {"type", "image_url", "detail"})
|
| 255 |
-
and (image_url := block.get("image_url"))
|
| 256 |
-
and isinstance(image_url, dict)
|
| 257 |
-
):
|
| 258 |
-
url = image_url.get("url")
|
| 259 |
-
if isinstance(url, str):
|
| 260 |
-
return True
|
| 261 |
-
else:
|
| 262 |
-
return False
|
| 263 |
-
|
| 264 |
-
return False
|
| 265 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 266 |
|
| 267 |
-
|
| 268 |
-
|
| 269 |
-
|
| 270 |
-
"""Converts a list of LangChain messages into a google parts."""
|
| 271 |
-
parts = []
|
| 272 |
-
content = [raw_content] if isinstance(raw_content, str) else raw_content
|
| 273 |
-
image_loader = ImageBytesLoader()
|
| 274 |
-
for part in content:
|
| 275 |
-
if isinstance(part, str):
|
| 276 |
-
parts.append(Part(text=part))
|
| 277 |
-
elif isinstance(part, Mapping):
|
| 278 |
-
if _is_lc_content_block(part):
|
| 279 |
-
if part["type"] == "text":
|
| 280 |
-
parts.append(Part(text=part["text"]))
|
| 281 |
-
elif is_data_content_block(part):
|
| 282 |
-
if part["source_type"] == "url":
|
| 283 |
-
bytes_ = image_loader._bytes_from_url(part["url"])
|
| 284 |
-
elif part["source_type"] == "base64":
|
| 285 |
-
bytes_ = base64.b64decode(part["data"])
|
| 286 |
-
else:
|
| 287 |
-
raise ValueError("source_type must be url or base64.")
|
| 288 |
-
inline_data: dict = {"data": bytes_}
|
| 289 |
-
if "mime_type" in part:
|
| 290 |
-
inline_data["mime_type"] = part["mime_type"]
|
| 291 |
-
else:
|
| 292 |
-
source = cast(str, part.get("url") or part.get("data"))
|
| 293 |
-
mime_type, _ = mimetypes.guess_type(source)
|
| 294 |
-
if not mime_type:
|
| 295 |
-
kind = filetype.guess(bytes_)
|
| 296 |
-
if kind:
|
| 297 |
-
mime_type = kind.mime
|
| 298 |
-
if mime_type:
|
| 299 |
-
inline_data["mime_type"] = mime_type
|
| 300 |
-
parts.append(Part(inline_data=inline_data))
|
| 301 |
-
elif part["type"] == "image_url":
|
| 302 |
-
img_url = part["image_url"]
|
| 303 |
-
if isinstance(img_url, dict):
|
| 304 |
-
if "url" not in img_url:
|
| 305 |
-
raise ValueError(
|
| 306 |
-
f"Unrecognized message image format: {img_url}"
|
| 307 |
-
)
|
| 308 |
-
img_url = img_url["url"]
|
| 309 |
-
parts.append(image_loader.load_part(img_url))
|
| 310 |
-
# Handle media type like LangChain.js
|
| 311 |
-
# https://github.com/langchain-ai/langchainjs/blob/e536593e2585f1dd7b0afc187de4d07cb40689ba/libs/langchain-google-common/src/utils/gemini.ts#L93-L106
|
| 312 |
-
elif part["type"] == "media":
|
| 313 |
-
if "mime_type" not in part:
|
| 314 |
-
raise ValueError(f"Missing mime_type in media part: {part}")
|
| 315 |
-
mime_type = part["mime_type"]
|
| 316 |
-
media_part = Part()
|
| 317 |
|
| 318 |
-
|
| 319 |
-
|
| 320 |
-
|
| 321 |
-
)
|
| 322 |
-
elif "file_uri" in part:
|
| 323 |
-
media_part.file_data = FileData(
|
| 324 |
-
file_uri=part["file_uri"], mime_type=mime_type
|
| 325 |
-
)
|
| 326 |
-
else:
|
| 327 |
-
raise ValueError(
|
| 328 |
-
f"Media part must have either data or file_uri: {part}"
|
| 329 |
-
)
|
| 330 |
-
if "video_metadata" in part:
|
| 331 |
-
metadata = VideoMetadata(part["video_metadata"])
|
| 332 |
-
media_part.video_metadata = metadata
|
| 333 |
-
parts.append(media_part)
|
| 334 |
-
else:
|
| 335 |
-
raise ValueError(
|
| 336 |
-
f"Unrecognized message part type: {part['type']}. Only text, "
|
| 337 |
-
f"image_url, and media types are supported."
|
| 338 |
-
)
|
| 339 |
-
else:
|
| 340 |
-
# Yolo
|
| 341 |
-
logger.warning(
|
| 342 |
-
"Unrecognized message part format. Assuming it's a text part."
|
| 343 |
-
)
|
| 344 |
-
parts.append(Part(text=str(part)))
|
| 345 |
-
else:
|
| 346 |
-
# TODO: Maybe some of Google's native stuff
|
| 347 |
-
# would hit this branch.
|
| 348 |
-
raise ChatGoogleGenerativeAIError(
|
| 349 |
-
"Gemini only supports text and inline_data parts."
|
| 350 |
-
)
|
| 351 |
-
return parts
|
| 352 |
-
|
| 353 |
-
|
| 354 |
-
def _convert_tool_message_to_parts(
|
| 355 |
-
message: ToolMessage | FunctionMessage, name: Optional[str] = None
|
| 356 |
-
) -> list[Part]:
|
| 357 |
-
"""Converts a tool or function message to a google part."""
|
| 358 |
-
# Legacy agent stores tool name in message.additional_kwargs instead of message.name
|
| 359 |
-
name = message.name or name or message.additional_kwargs.get("name")
|
| 360 |
-
response: Any
|
| 361 |
-
parts: list[Part] = []
|
| 362 |
-
if isinstance(message.content, list):
|
| 363 |
-
media_blocks = []
|
| 364 |
-
other_blocks = []
|
| 365 |
-
for block in message.content:
|
| 366 |
-
if isinstance(block, dict) and (
|
| 367 |
-
is_data_content_block(block) or _is_openai_image_block(block)
|
| 368 |
-
):
|
| 369 |
-
media_blocks.append(block)
|
| 370 |
-
else:
|
| 371 |
-
other_blocks.append(block)
|
| 372 |
-
parts.extend(_convert_to_parts(media_blocks))
|
| 373 |
-
response = other_blocks
|
| 374 |
-
|
| 375 |
-
elif not isinstance(message.content, str):
|
| 376 |
-
response = message.content
|
| 377 |
-
else:
|
| 378 |
-
try:
|
| 379 |
-
response = json.loads(message.content)
|
| 380 |
-
except json.JSONDecodeError:
|
| 381 |
-
response = message.content # leave as str representation
|
| 382 |
-
part = Part(
|
| 383 |
-
function_response=FunctionResponse(
|
| 384 |
-
name=name,
|
| 385 |
-
response=(
|
| 386 |
-
{"output": response} if not isinstance(response, dict) else response
|
| 387 |
-
),
|
| 388 |
-
)
|
| 389 |
-
)
|
| 390 |
-
parts.append(part)
|
| 391 |
-
return parts
|
| 392 |
-
|
| 393 |
-
|
| 394 |
-
def _get_ai_message_tool_messages_parts(
|
| 395 |
-
tool_messages: Sequence[ToolMessage], ai_message: AIMessage
|
| 396 |
-
) -> list[Part]:
|
| 397 |
-
"""
|
| 398 |
-
Finds relevant tool messages for the AI message and converts them to a single
|
| 399 |
-
list of Parts.
|
| 400 |
-
"""
|
| 401 |
-
# We are interested only in the tool messages that are part of the AI message
|
| 402 |
-
tool_calls_ids = {tool_call["id"]: tool_call for tool_call in ai_message.tool_calls}
|
| 403 |
-
parts = []
|
| 404 |
-
for i, message in enumerate(tool_messages):
|
| 405 |
-
if not tool_calls_ids:
|
| 406 |
-
break
|
| 407 |
-
if message.tool_call_id in tool_calls_ids:
|
| 408 |
-
tool_call = tool_calls_ids[message.tool_call_id]
|
| 409 |
-
message_parts = _convert_tool_message_to_parts(
|
| 410 |
-
message, name=tool_call.get("name")
|
| 411 |
-
)
|
| 412 |
-
parts.extend(message_parts)
|
| 413 |
-
# remove the id from the dict, so that we do not iterate over it again
|
| 414 |
-
tool_calls_ids.pop(message.tool_call_id)
|
| 415 |
-
return parts
|
| 416 |
-
|
| 417 |
-
|
| 418 |
-
def _parse_chat_history(
|
| 419 |
-
input_messages: Sequence[BaseMessage], convert_system_message_to_human: bool = False
|
| 420 |
-
) -> Tuple[Optional[Content], List[Content]]:
|
| 421 |
-
messages: List[Content] = []
|
| 422 |
-
|
| 423 |
-
if convert_system_message_to_human:
|
| 424 |
-
warnings.warn("Convert_system_message_to_human will be deprecated!")
|
| 425 |
-
|
| 426 |
-
system_instruction: Optional[Content] = None
|
| 427 |
-
messages_without_tool_messages = [
|
| 428 |
-
message for message in input_messages if not isinstance(message, ToolMessage)
|
| 429 |
-
]
|
| 430 |
-
tool_messages = [
|
| 431 |
-
message for message in input_messages if isinstance(message, ToolMessage)
|
| 432 |
-
]
|
| 433 |
-
for i, message in enumerate(messages_without_tool_messages):
|
| 434 |
-
if isinstance(message, SystemMessage):
|
| 435 |
-
system_parts = _convert_to_parts(message.content)
|
| 436 |
-
if i == 0:
|
| 437 |
-
system_instruction = Content(parts=system_parts)
|
| 438 |
-
elif system_instruction is not None:
|
| 439 |
-
system_instruction.parts.extend(system_parts)
|
| 440 |
-
else:
|
| 441 |
-
pass
|
| 442 |
-
continue
|
| 443 |
-
elif isinstance(message, AIMessage):
|
| 444 |
-
role = "model"
|
| 445 |
-
if message.tool_calls:
|
| 446 |
-
ai_message_parts = []
|
| 447 |
-
for tool_call in message.tool_calls:
|
| 448 |
-
function_call = FunctionCall(
|
| 449 |
-
{
|
| 450 |
-
"name": tool_call["name"],
|
| 451 |
-
"args": tool_call["args"],
|
| 452 |
-
}
|
| 453 |
-
)
|
| 454 |
-
ai_message_parts.append(Part(function_call=function_call))
|
| 455 |
-
tool_messages_parts = _get_ai_message_tool_messages_parts(
|
| 456 |
-
tool_messages=tool_messages, ai_message=message
|
| 457 |
-
)
|
| 458 |
-
messages.append(Content(role=role, parts=ai_message_parts))
|
| 459 |
-
messages.append(Content(role="user", parts=tool_messages_parts))
|
| 460 |
-
continue
|
| 461 |
-
elif raw_function_call := message.additional_kwargs.get("function_call"):
|
| 462 |
-
function_call = FunctionCall(
|
| 463 |
-
{
|
| 464 |
-
"name": raw_function_call["name"],
|
| 465 |
-
"args": json.loads(raw_function_call["arguments"]),
|
| 466 |
-
}
|
| 467 |
-
)
|
| 468 |
-
parts = [Part(function_call=function_call)]
|
| 469 |
-
else:
|
| 470 |
-
parts = _convert_to_parts(message.content)
|
| 471 |
-
elif isinstance(message, HumanMessage):
|
| 472 |
-
role = "user"
|
| 473 |
-
parts = _convert_to_parts(message.content)
|
| 474 |
-
if i == 1 and convert_system_message_to_human and system_instruction:
|
| 475 |
-
parts = [p for p in system_instruction.parts] + parts
|
| 476 |
-
system_instruction = None
|
| 477 |
-
elif isinstance(message, FunctionMessage):
|
| 478 |
-
role = "user"
|
| 479 |
-
parts = _convert_tool_message_to_parts(message)
|
| 480 |
-
else:
|
| 481 |
-
raise ValueError(
|
| 482 |
-
f"Unexpected message with type {type(message)} at the position {i}."
|
| 483 |
-
)
|
| 484 |
-
|
| 485 |
-
messages.append(Content(role=role, parts=parts))
|
| 486 |
-
return system_instruction, messages
|
| 487 |
-
|
| 488 |
-
|
| 489 |
-
def _parse_response_candidate(
|
| 490 |
-
response_candidate: Candidate, streaming: bool = False
|
| 491 |
-
) -> AIMessage:
|
| 492 |
-
content: Union[None, str, List[Union[str, dict]]] = None
|
| 493 |
-
additional_kwargs = {}
|
| 494 |
-
tool_calls = []
|
| 495 |
-
invalid_tool_calls = []
|
| 496 |
-
tool_call_chunks = []
|
| 497 |
-
|
| 498 |
-
for part in response_candidate.content.parts:
|
| 499 |
-
try:
|
| 500 |
-
text: Optional[str] = part.text
|
| 501 |
-
# Remove erroneous newline character if present
|
| 502 |
-
if not streaming and text is not None:
|
| 503 |
-
text = text.rstrip("\n")
|
| 504 |
-
except AttributeError:
|
| 505 |
-
text = None
|
| 506 |
-
|
| 507 |
-
if part.thought:
|
| 508 |
-
thinking_message = {
|
| 509 |
-
"type": "thinking",
|
| 510 |
-
"thinking": part.text,
|
| 511 |
-
}
|
| 512 |
-
if not content:
|
| 513 |
-
content = [thinking_message]
|
| 514 |
-
elif isinstance(content, str):
|
| 515 |
-
content = [thinking_message, content]
|
| 516 |
-
elif isinstance(content, list):
|
| 517 |
-
content.append(thinking_message)
|
| 518 |
-
else:
|
| 519 |
-
raise Exception("Unexpected content type")
|
| 520 |
-
|
| 521 |
-
elif text is not None:
|
| 522 |
-
if not content:
|
| 523 |
-
content = text
|
| 524 |
-
elif isinstance(content, str) and text:
|
| 525 |
-
content = [content, text]
|
| 526 |
-
elif isinstance(content, list) and text:
|
| 527 |
-
content.append(text)
|
| 528 |
-
elif text:
|
| 529 |
-
raise Exception("Unexpected content type")
|
| 530 |
-
|
| 531 |
-
if hasattr(part, "executable_code") and part.executable_code is not None:
|
| 532 |
-
if part.executable_code.code and part.executable_code.language:
|
| 533 |
-
code_message = {
|
| 534 |
-
"type": "executable_code",
|
| 535 |
-
"executable_code": part.executable_code.code,
|
| 536 |
-
"language": part.executable_code.language,
|
| 537 |
-
}
|
| 538 |
-
if not content:
|
| 539 |
-
content = [code_message]
|
| 540 |
-
elif isinstance(content, str):
|
| 541 |
-
content = [content, code_message]
|
| 542 |
-
elif isinstance(content, list):
|
| 543 |
-
content.append(code_message)
|
| 544 |
-
else:
|
| 545 |
-
raise Exception("Unexpected content type")
|
| 546 |
-
|
| 547 |
-
if (
|
| 548 |
-
hasattr(part, "code_execution_result")
|
| 549 |
-
and part.code_execution_result is not None
|
| 550 |
-
):
|
| 551 |
-
if part.code_execution_result.output:
|
| 552 |
-
execution_result = {
|
| 553 |
-
"type": "code_execution_result",
|
| 554 |
-
"code_execution_result": part.code_execution_result.output,
|
| 555 |
-
}
|
| 556 |
-
|
| 557 |
-
if not content:
|
| 558 |
-
content = [execution_result]
|
| 559 |
-
elif isinstance(content, str):
|
| 560 |
-
content = [content, execution_result]
|
| 561 |
-
elif isinstance(content, list):
|
| 562 |
-
content.append(execution_result)
|
| 563 |
-
else:
|
| 564 |
-
raise Exception("Unexpected content type")
|
| 565 |
|
| 566 |
-
|
| 567 |
-
image_format = part.inline_data.mime_type[6:]
|
| 568 |
-
message = {
|
| 569 |
-
"type": "image_url",
|
| 570 |
-
"image_url": {
|
| 571 |
-
"url": image_bytes_to_b64_string(
|
| 572 |
-
part.inline_data.data, image_format=image_format
|
| 573 |
-
)
|
| 574 |
-
},
|
| 575 |
-
}
|
| 576 |
-
|
| 577 |
-
if not content:
|
| 578 |
-
content = [message]
|
| 579 |
-
elif isinstance(content, str) and message:
|
| 580 |
-
content = [content, message]
|
| 581 |
-
elif isinstance(content, list) and message:
|
| 582 |
-
content.append(message)
|
| 583 |
-
elif message:
|
| 584 |
-
raise Exception("Unexpected content type")
|
| 585 |
-
|
| 586 |
-
if part.function_call:
|
| 587 |
-
function_call = {"name": part.function_call.name}
|
| 588 |
-
# dump to match other function calling llm for now
|
| 589 |
-
function_call_args_dict = proto.Message.to_dict(part.function_call)["args"]
|
| 590 |
-
function_call["arguments"] = json.dumps(
|
| 591 |
-
{k: function_call_args_dict[k] for k in function_call_args_dict}
|
| 592 |
-
)
|
| 593 |
-
additional_kwargs["function_call"] = function_call
|
| 594 |
-
|
| 595 |
-
if streaming:
|
| 596 |
-
tool_call_chunks.append(
|
| 597 |
-
tool_call_chunk(
|
| 598 |
-
name=function_call.get("name"),
|
| 599 |
-
args=function_call.get("arguments"),
|
| 600 |
-
id=function_call.get("id", str(uuid.uuid4())),
|
| 601 |
-
index=function_call.get("index"), # type: ignore
|
| 602 |
-
)
|
| 603 |
-
)
|
| 604 |
-
else:
|
| 605 |
-
try:
|
| 606 |
-
tool_call_dict = parse_tool_calls(
|
| 607 |
-
[{"function": function_call}],
|
| 608 |
-
return_id=False,
|
| 609 |
-
)[0]
|
| 610 |
-
except Exception as e:
|
| 611 |
-
invalid_tool_calls.append(
|
| 612 |
-
invalid_tool_call(
|
| 613 |
-
name=function_call.get("name"),
|
| 614 |
-
args=function_call.get("arguments"),
|
| 615 |
-
id=function_call.get("id", str(uuid.uuid4())),
|
| 616 |
-
error=str(e),
|
| 617 |
-
)
|
| 618 |
-
)
|
| 619 |
-
else:
|
| 620 |
-
tool_calls.append(
|
| 621 |
-
tool_call(
|
| 622 |
-
name=tool_call_dict["name"],
|
| 623 |
-
args=tool_call_dict["args"],
|
| 624 |
-
id=tool_call_dict.get("id", str(uuid.uuid4())),
|
| 625 |
-
)
|
| 626 |
-
)
|
| 627 |
-
if content is None:
|
| 628 |
-
content = ""
|
| 629 |
-
if any(isinstance(item, dict) and "executable_code" in item for item in content):
|
| 630 |
-
warnings.warn(
|
| 631 |
-
"""
|
| 632 |
-
⚠️ Warning: Output may vary each run.
|
| 633 |
-
- 'executable_code': Always present.
|
| 634 |
-
- 'execution_result' & 'image_url': May be absent for some queries.
|
| 635 |
-
|
| 636 |
-
Validate before using in production.
|
| 637 |
"""
|
| 638 |
-
)
|
| 639 |
|
| 640 |
-
|
| 641 |
-
|
| 642 |
-
|
| 643 |
-
|
| 644 |
-
|
| 645 |
-
|
| 646 |
-
|
| 647 |
-
|
| 648 |
-
|
| 649 |
-
|
| 650 |
-
|
| 651 |
-
|
| 652 |
-
|
| 653 |
-
|
| 654 |
-
|
| 655 |
-
|
| 656 |
-
|
| 657 |
-
|
| 658 |
-
|
| 659 |
-
|
| 660 |
-
"""Converts a PaLM API response into a LangChain ChatResult."""
|
| 661 |
-
llm_output = {"prompt_feedback": proto.Message.to_dict(response.prompt_feedback)}
|
| 662 |
-
|
| 663 |
-
# previous usage metadata needs to be subtracted because gemini api returns
|
| 664 |
-
# already-accumulated token counts with each chunk
|
| 665 |
-
prev_input_tokens = prev_usage["input_tokens"] if prev_usage else 0
|
| 666 |
-
prev_output_tokens = prev_usage["output_tokens"] if prev_usage else 0
|
| 667 |
-
prev_total_tokens = prev_usage["total_tokens"] if prev_usage else 0
|
| 668 |
|
| 669 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 670 |
try:
|
| 671 |
-
|
| 672 |
-
|
| 673 |
-
|
| 674 |
-
|
| 675 |
-
|
| 676 |
-
|
| 677 |
-
|
| 678 |
-
|
| 679 |
-
|
| 680 |
-
|
| 681 |
-
|
| 682 |
-
|
| 683 |
-
|
| 684 |
-
|
| 685 |
-
|
| 686 |
-
|
| 687 |
-
|
| 688 |
-
|
| 689 |
-
|
| 690 |
-
|
| 691 |
-
|
| 692 |
-
|
| 693 |
-
|
| 694 |
-
|
| 695 |
-
|
| 696 |
-
|
| 697 |
-
|
| 698 |
-
|
| 699 |
-
|
| 700 |
-
|
| 701 |
-
if
|
| 702 |
-
|
| 703 |
-
|
| 704 |
-
|
| 705 |
-
|
| 706 |
-
|
| 707 |
-
|
| 708 |
-
|
| 709 |
-
try:
|
| 710 |
-
if candidate.grounding_metadata:
|
| 711 |
-
generation_info["grounding_metadata"] = proto.Message.to_dict(
|
| 712 |
-
candidate.grounding_metadata
|
| 713 |
-
)
|
| 714 |
-
except AttributeError:
|
| 715 |
-
pass
|
| 716 |
-
message = _parse_response_candidate(candidate, streaming=stream)
|
| 717 |
-
message.usage_metadata = lc_usage
|
| 718 |
-
if stream:
|
| 719 |
-
generations.append(
|
| 720 |
-
ChatGenerationChunk(
|
| 721 |
-
message=cast(AIMessageChunk, message),
|
| 722 |
-
generation_info=generation_info,
|
| 723 |
-
)
|
| 724 |
-
)
|
| 725 |
-
else:
|
| 726 |
-
generations.append(
|
| 727 |
-
ChatGeneration(message=message, generation_info=generation_info)
|
| 728 |
-
)
|
| 729 |
-
if not response.candidates:
|
| 730 |
-
# Likely a "prompt feedback" violation (e.g., toxic input)
|
| 731 |
-
# Raising an error would be different than how OpenAI handles it,
|
| 732 |
-
# so we'll just log a warning and continue with an empty message.
|
| 733 |
-
logger.warning(
|
| 734 |
-
"Gemini produced an empty response. Continuing with empty message\n"
|
| 735 |
-
f"Feedback: {response.prompt_feedback}"
|
| 736 |
-
)
|
| 737 |
-
if stream:
|
| 738 |
-
generations = [
|
| 739 |
-
ChatGenerationChunk(
|
| 740 |
-
message=AIMessageChunk(content=""), generation_info={}
|
| 741 |
-
)
|
| 742 |
-
]
|
| 743 |
-
else:
|
| 744 |
-
generations = [ChatGeneration(message=AIMessage(""), generation_info={})]
|
| 745 |
-
return ChatResult(generations=generations, llm_output=llm_output)
|
| 746 |
-
|
| 747 |
-
|
| 748 |
-
def _is_event_loop_running() -> bool:
|
| 749 |
try:
|
| 750 |
-
|
| 751 |
-
|
| 752 |
-
|
| 753 |
-
|
| 754 |
-
|
| 755 |
-
|
| 756 |
-
|
| 757 |
-
"""
|
| 758 |
-
|
| 759 |
-
|
| 760 |
-
|
| 761 |
-
|
| 762 |
-
1. The ``GOOGLE_API_KEY`` environment variable set with your API key, or
|
| 763 |
-
2. Pass your API key using the google_api_key kwarg
|
| 764 |
-
to the ChatGoogleGenerativeAI constructor.
|
| 765 |
-
|
| 766 |
-
.. code-block:: python
|
| 767 |
-
|
| 768 |
-
from langchain_google_genai import ChatGoogleGenerativeAI
|
| 769 |
-
|
| 770 |
-
llm = ChatGoogleGenerativeAI(model="gemini-2.0-flash-001")
|
| 771 |
-
llm.invoke("Write me a ballad about LangChain")
|
| 772 |
-
|
| 773 |
-
Invoke:
|
| 774 |
-
.. code-block:: python
|
| 775 |
-
|
| 776 |
-
messages = [
|
| 777 |
-
("system", "Translate the user sentence to French."),
|
| 778 |
-
("human", "I love programming."),
|
| 779 |
-
]
|
| 780 |
-
llm.invoke(messages)
|
| 781 |
-
|
| 782 |
-
.. code-block:: python
|
| 783 |
-
|
| 784 |
-
AIMessage(
|
| 785 |
-
content="J'adore programmer. \\n",
|
| 786 |
-
response_metadata={'prompt_feedback': {'block_reason': 0, 'safety_ratings': []}, 'finish_reason': 'STOP', 'safety_ratings': [{'category': 'HARM_CATEGORY_SEXUALLY_EXPLICIT', 'probability': 'NEGLIGIBLE', 'blocked': False}, {'category': 'HARM_CATEGORY_HATE_SPEECH', 'probability': 'NEGLIGIBLE', 'blocked': False}, {'category': 'HARM_CATEGORY_HARASSMENT', 'probability': 'NEGLIGIBLE', 'blocked': False}, {'category': 'HARM_CATEGORY_DANGEROUS_CONTENT', 'probability': 'NEGLIGIBLE', 'blocked': False}]},
|
| 787 |
-
id='run-56cecc34-2e54-4b52-a974-337e47008ad2-0',
|
| 788 |
-
usage_metadata={'input_tokens': 18, 'output_tokens': 5, 'total_tokens': 23}
|
| 789 |
-
)
|
| 790 |
-
|
| 791 |
-
Stream:
|
| 792 |
-
.. code-block:: python
|
| 793 |
-
|
| 794 |
-
for chunk in llm.stream(messages):
|
| 795 |
-
print(chunk)
|
| 796 |
-
|
| 797 |
-
.. code-block:: python
|
| 798 |
-
|
| 799 |
-
AIMessageChunk(content='J', response_metadata={'finish_reason': 'STOP', 'safety_ratings': []}, id='run-e905f4f4-58cb-4a10-a960-448a2bb649e3', usage_metadata={'input_tokens': 18, 'output_tokens': 1, 'total_tokens': 19})
|
| 800 |
-
AIMessageChunk(content="'adore programmer. \\n", response_metadata={'finish_reason': 'STOP', 'safety_ratings': [{'category': 'HARM_CATEGORY_SEXUALLY_EXPLICIT', 'probability': 'NEGLIGIBLE', 'blocked': False}, {'category': 'HARM_CATEGORY_HATE_SPEECH', 'probability': 'NEGLIGIBLE', 'blocked': False}, {'category': 'HARM_CATEGORY_HARASSMENT', 'probability': 'NEGLIGIBLE', 'blocked': False}, {'category': 'HARM_CATEGORY_DANGEROUS_CONTENT', 'probability': 'NEGLIGIBLE', 'blocked': False}]}, id='run-e905f4f4-58cb-4a10-a960-448a2bb649e3', usage_metadata={'input_tokens': 18, 'output_tokens': 5, 'total_tokens': 23})
|
| 801 |
-
|
| 802 |
-
.. code-block:: python
|
| 803 |
-
|
| 804 |
-
stream = llm.stream(messages)
|
| 805 |
-
full = next(stream)
|
| 806 |
-
for chunk in stream:
|
| 807 |
-
full += chunk
|
| 808 |
-
full
|
| 809 |
-
|
| 810 |
-
.. code-block:: python
|
| 811 |
-
|
| 812 |
-
AIMessageChunk(
|
| 813 |
-
content="J'adore programmer. \\n",
|
| 814 |
-
response_metadata={'finish_reason': 'STOPSTOP', 'safety_ratings': [{'category': 'HARM_CATEGORY_SEXUALLY_EXPLICIT', 'probability': 'NEGLIGIBLE', 'blocked': False}, {'category': 'HARM_CATEGORY_HATE_SPEECH', 'probability': 'NEGLIGIBLE', 'blocked': False}, {'category': 'HARM_CATEGORY_HARASSMENT', 'probability': 'NEGLIGIBLE', 'blocked': False}, {'category': 'HARM_CATEGORY_DANGEROUS_CONTENT', 'probability': 'NEGLIGIBLE', 'blocked': False}]},
|
| 815 |
-
id='run-3ce13a42-cd30-4ad7-a684-f1f0b37cdeec',
|
| 816 |
-
usage_metadata={'input_tokens': 36, 'output_tokens': 6, 'total_tokens': 42}
|
| 817 |
-
)
|
| 818 |
-
|
| 819 |
-
Async:
|
| 820 |
-
.. code-block:: python
|
| 821 |
-
|
| 822 |
-
await llm.ainvoke(messages)
|
| 823 |
-
|
| 824 |
-
# stream:
|
| 825 |
-
# async for chunk in (await llm.astream(messages))
|
| 826 |
-
|
| 827 |
-
# batch:
|
| 828 |
-
# await llm.abatch([messages])
|
| 829 |
-
|
| 830 |
-
Context Caching:
|
| 831 |
-
Context caching allows you to store and reuse content (e.g., PDFs, images) for faster processing.
|
| 832 |
-
The `cached_content` parameter accepts a cache name created via the Google Generative AI API.
|
| 833 |
-
Below are two examples: caching a single file directly and caching multiple files using `Part`.
|
| 834 |
-
|
| 835 |
-
Single File Example:
|
| 836 |
-
This caches a single file and queries it.
|
| 837 |
-
|
| 838 |
-
.. code-block:: python
|
| 839 |
-
|
| 840 |
-
from google import genai
|
| 841 |
-
from google.genai import types
|
| 842 |
-
import time
|
| 843 |
-
from langchain_google_genai import ChatGoogleGenerativeAI
|
| 844 |
-
from langchain_core.messages import HumanMessage
|
| 845 |
-
|
| 846 |
-
client = genai.Client()
|
| 847 |
-
|
| 848 |
-
# Upload file
|
| 849 |
-
file = client.files.upload(file="./example_file")
|
| 850 |
-
while file.state.name == 'PROCESSING':
|
| 851 |
-
time.sleep(2)
|
| 852 |
-
file = client.files.get(name=file.name)
|
| 853 |
-
|
| 854 |
-
# Create cache
|
| 855 |
-
model = 'models/gemini-1.5-flash-latest'
|
| 856 |
-
cache = client.caches.create(
|
| 857 |
-
model=model,
|
| 858 |
-
config=types.CreateCachedContentConfig(
|
| 859 |
-
display_name='Cached Content',
|
| 860 |
-
system_instruction=(
|
| 861 |
-
'You are an expert content analyzer, and your job is to answer '
|
| 862 |
-
'the user\'s query based on the file you have access to.'
|
| 863 |
-
),
|
| 864 |
-
contents=[file],
|
| 865 |
-
ttl="300s",
|
| 866 |
-
)
|
| 867 |
-
)
|
| 868 |
-
|
| 869 |
-
# Query with LangChain
|
| 870 |
-
llm = ChatGoogleGenerativeAI(
|
| 871 |
-
model=model,
|
| 872 |
-
cached_content=cache.name,
|
| 873 |
-
)
|
| 874 |
-
message = HumanMessage(content="Summarize the main points of the content.")
|
| 875 |
-
llm.invoke([message])
|
| 876 |
-
|
| 877 |
-
Multiple Files Example:
|
| 878 |
-
This caches two files using `Part` and queries them together.
|
| 879 |
-
|
| 880 |
-
.. code-block:: python
|
| 881 |
-
|
| 882 |
-
from google import genai
|
| 883 |
-
from google.genai.types import CreateCachedContentConfig, Content, Part
|
| 884 |
-
import time
|
| 885 |
-
from langchain_google_genai import ChatGoogleGenerativeAI
|
| 886 |
-
from langchain_core.messages import HumanMessage
|
| 887 |
-
|
| 888 |
-
client = genai.Client()
|
| 889 |
-
|
| 890 |
-
# Upload files
|
| 891 |
-
file_1 = client.files.upload(file="./file1")
|
| 892 |
-
while file_1.state.name == 'PROCESSING':
|
| 893 |
-
time.sleep(2)
|
| 894 |
-
file_1 = client.files.get(name=file_1.name)
|
| 895 |
-
|
| 896 |
-
file_2 = client.files.upload(file="./file2")
|
| 897 |
-
while file_2.state.name == 'PROCESSING':
|
| 898 |
-
time.sleep(2)
|
| 899 |
-
file_2 = client.files.get(name=file_2.name)
|
| 900 |
-
|
| 901 |
-
# Create cache with multiple files
|
| 902 |
-
contents = [
|
| 903 |
-
Content(
|
| 904 |
-
role="user",
|
| 905 |
-
parts=[
|
| 906 |
-
Part.from_uri(file_uri=file_1.uri, mime_type=file_1.mime_type),
|
| 907 |
-
Part.from_uri(file_uri=file_2.uri, mime_type=file_2.mime_type),
|
| 908 |
-
],
|
| 909 |
-
)
|
| 910 |
-
]
|
| 911 |
-
model = "gemini-1.5-flash-latest"
|
| 912 |
-
cache = client.caches.create(
|
| 913 |
-
model=model,
|
| 914 |
-
config=CreateCachedContentConfig(
|
| 915 |
-
display_name='Cached Contents',
|
| 916 |
-
system_instruction=(
|
| 917 |
-
'You are an expert content analyzer, and your job is to answer '
|
| 918 |
-
'the user\'s query based on the files you have access to.'
|
| 919 |
-
),
|
| 920 |
-
contents=contents,
|
| 921 |
-
ttl="300s",
|
| 922 |
-
)
|
| 923 |
-
)
|
| 924 |
-
|
| 925 |
-
# Query with LangChain
|
| 926 |
llm = ChatGoogleGenerativeAI(
|
| 927 |
-
model=
|
| 928 |
-
|
|
|
|
| 929 |
)
|
| 930 |
-
|
| 931 |
-
|
| 932 |
-
|
| 933 |
-
|
| 934 |
-
|
| 935 |
-
|
| 936 |
-
|
| 937 |
-
|
| 938 |
-
|
| 939 |
-
class GetWeather(BaseModel):
|
| 940 |
-
'''Get the current weather in a given location'''
|
| 941 |
-
|
| 942 |
-
location: str = Field(
|
| 943 |
-
..., description="The city and state, e.g. San Francisco, CA"
|
| 944 |
-
)
|
| 945 |
-
|
| 946 |
-
|
| 947 |
-
class GetPopulation(BaseModel):
|
| 948 |
-
'''Get the current population in a given location'''
|
| 949 |
-
|
| 950 |
-
location: str = Field(
|
| 951 |
-
..., description="The city and state, e.g. San Francisco, CA"
|
| 952 |
-
)
|
| 953 |
-
|
| 954 |
-
|
| 955 |
-
llm_with_tools = llm.bind_tools([GetWeather, GetPopulation])
|
| 956 |
-
ai_msg = llm_with_tools.invoke(
|
| 957 |
-
"Which city is hotter today and which is bigger: LA or NY?"
|
| 958 |
-
)
|
| 959 |
-
ai_msg.tool_calls
|
| 960 |
-
|
| 961 |
-
.. code-block:: python
|
| 962 |
-
|
| 963 |
-
[{'name': 'GetWeather',
|
| 964 |
-
'args': {'location': 'Los Angeles, CA'},
|
| 965 |
-
'id': 'c186c99f-f137-4d52-947f-9e3deabba6f6'},
|
| 966 |
-
{'name': 'GetWeather',
|
| 967 |
-
'args': {'location': 'New York City, NY'},
|
| 968 |
-
'id': 'cebd4a5d-e800-4fa5-babd-4aa286af4f31'},
|
| 969 |
-
{'name': 'GetPopulation',
|
| 970 |
-
'args': {'location': 'Los Angeles, CA'},
|
| 971 |
-
'id': '4f92d897-f5e4-4d34-a3bc-93062c92591e'},
|
| 972 |
-
{'name': 'GetPopulation',
|
| 973 |
-
'args': {'location': 'New York City, NY'},
|
| 974 |
-
'id': '634582de-5186-4e4b-968b-f192f0a93678'}]
|
| 975 |
-
|
| 976 |
-
Use Search with Gemini 2:
|
| 977 |
-
.. code-block:: python
|
| 978 |
-
|
| 979 |
-
from google.ai.generativelanguage_v1beta.types import Tool as GenAITool
|
| 980 |
-
llm = ChatGoogleGenerativeAI(model="gemini-2.0-flash-exp")
|
| 981 |
-
resp = llm.invoke(
|
| 982 |
-
"When is the next total solar eclipse in US?",
|
| 983 |
-
tools=[GenAITool(google_search={})],
|
| 984 |
-
)
|
| 985 |
-
|
| 986 |
-
Structured output:
|
| 987 |
-
.. code-block:: python
|
| 988 |
-
|
| 989 |
-
from typing import Optional
|
| 990 |
-
|
| 991 |
-
from pydantic import BaseModel, Field
|
| 992 |
-
|
| 993 |
-
|
| 994 |
-
class Joke(BaseModel):
|
| 995 |
-
'''Joke to tell user.'''
|
| 996 |
-
|
| 997 |
-
setup: str = Field(description="The setup of the joke")
|
| 998 |
-
punchline: str = Field(description="The punchline to the joke")
|
| 999 |
-
rating: Optional[int] = Field(description="How funny the joke is, from 1 to 10")
|
| 1000 |
-
|
| 1001 |
-
|
| 1002 |
-
structured_llm = llm.with_structured_output(Joke)
|
| 1003 |
-
structured_llm.invoke("Tell me a joke about cats")
|
| 1004 |
-
|
| 1005 |
-
.. code-block:: python
|
| 1006 |
-
|
| 1007 |
-
Joke(
|
| 1008 |
-
setup='Why are cats so good at video games?',
|
| 1009 |
-
punchline='They have nine lives on the internet',
|
| 1010 |
-
rating=None
|
| 1011 |
-
)
|
| 1012 |
-
|
| 1013 |
-
Image input:
|
| 1014 |
-
.. code-block:: python
|
| 1015 |
-
|
| 1016 |
-
import base64
|
| 1017 |
-
import httpx
|
| 1018 |
-
from langchain_core.messages import HumanMessage
|
| 1019 |
-
|
| 1020 |
-
image_url = "https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg"
|
| 1021 |
-
image_data = base64.b64encode(httpx.get(image_url).content).decode("utf-8")
|
| 1022 |
-
message = HumanMessage(
|
| 1023 |
-
content=[
|
| 1024 |
-
{"type": "text", "text": "describe the weather in this image"},
|
| 1025 |
-
{
|
| 1026 |
-
"type": "image_url",
|
| 1027 |
-
"image_url": {"url": f"data:image/jpeg;base64,{image_data}"},
|
| 1028 |
-
},
|
| 1029 |
-
]
|
| 1030 |
)
|
| 1031 |
-
|
| 1032 |
-
|
| 1033 |
-
|
| 1034 |
-
|
| 1035 |
-
|
| 1036 |
-
|
| 1037 |
-
|
| 1038 |
-
|
| 1039 |
-
|
| 1040 |
-
|
| 1041 |
-
|
| 1042 |
-
|
| 1043 |
-
|
| 1044 |
-
|
| 1045 |
-
|
| 1046 |
-
|
| 1047 |
-
|
| 1048 |
-
|
| 1049 |
-
|
| 1050 |
-
|
| 1051 |
-
|
| 1052 |
-
ai_msg = llm.invoke(messages)
|
| 1053 |
-
ai_msg.response_metadata
|
| 1054 |
-
|
| 1055 |
-
.. code-block:: python
|
| 1056 |
-
|
| 1057 |
-
{
|
| 1058 |
-
'prompt_feedback': {'block_reason': 0, 'safety_ratings': []},
|
| 1059 |
-
'finish_reason': 'STOP',
|
| 1060 |
-
'safety_ratings': [{'category': 'HARM_CATEGORY_SEXUALLY_EXPLICIT', 'probability': 'NEGLIGIBLE', 'blocked': False}, {'category': 'HARM_CATEGORY_HATE_SPEECH', 'probability': 'NEGLIGIBLE', 'blocked': False}, {'category': 'HARM_CATEGORY_HARASSMENT', 'probability': 'NEGLIGIBLE', 'blocked': False}, {'category': 'HARM_CATEGORY_DANGEROUS_CONTENT', 'probability': 'NEGLIGIBLE', 'blocked': False}]
|
| 1061 |
-
}
|
| 1062 |
-
|
| 1063 |
-
""" # noqa: E501
|
| 1064 |
-
|
| 1065 |
-
client: Any = Field(default=None, exclude=True) #: :meta private:
|
| 1066 |
-
async_client_running: Any = Field(default=None, exclude=True) #: :meta private:
|
| 1067 |
-
default_metadata: Sequence[Tuple[str, str]] = Field(
|
| 1068 |
-
default_factory=list
|
| 1069 |
-
) #: :meta private:
|
| 1070 |
-
|
| 1071 |
-
convert_system_message_to_human: bool = False
|
| 1072 |
-
"""Whether to merge any leading SystemMessage into the following HumanMessage.
|
| 1073 |
-
|
| 1074 |
-
Gemini does not support system messages; any unsupported messages will
|
| 1075 |
-
raise an error."""
|
| 1076 |
-
|
| 1077 |
-
cached_content: Optional[str] = None
|
| 1078 |
-
"""The name of the cached content used as context to serve the prediction.
|
| 1079 |
-
|
| 1080 |
-
Note: only used in explicit caching, where users can have control over caching
|
| 1081 |
-
(e.g. what content to cache) and enjoy guaranteed cost savings. Format:
|
| 1082 |
-
``cachedContents/{cachedContent}``.
|
| 1083 |
-
"""
|
| 1084 |
-
|
| 1085 |
-
model_kwargs: dict[str, Any] = Field(default_factory=dict)
|
| 1086 |
-
"""Holds any unexpected initialization parameters."""
|
| 1087 |
-
|
| 1088 |
-
def __init__(self, **kwargs: Any) -> None:
|
| 1089 |
-
"""Needed for arg validation."""
|
| 1090 |
-
# Get all valid field names, including aliases
|
| 1091 |
-
valid_fields = set()
|
| 1092 |
-
for field_name, field_info in self.__class__.model_fields.items():
|
| 1093 |
-
valid_fields.add(field_name)
|
| 1094 |
-
if hasattr(field_info, "alias") and field_info.alias is not None:
|
| 1095 |
-
valid_fields.add(field_info.alias)
|
| 1096 |
-
|
| 1097 |
-
# Check for unrecognized arguments
|
| 1098 |
-
for arg in kwargs:
|
| 1099 |
-
if arg not in valid_fields:
|
| 1100 |
-
suggestions = get_close_matches(arg, valid_fields, n=1)
|
| 1101 |
-
suggestion = (
|
| 1102 |
-
f" Did you mean: '{suggestions[0]}'?" if suggestions else ""
|
| 1103 |
-
)
|
| 1104 |
-
logger.warning(
|
| 1105 |
-
f"Unexpected argument '{arg}' "
|
| 1106 |
-
f"provided to ChatGoogleGenerativeAI.{suggestion}"
|
| 1107 |
-
)
|
| 1108 |
-
super().__init__(**kwargs)
|
| 1109 |
-
|
| 1110 |
-
model_config = ConfigDict(
|
| 1111 |
-
populate_by_name=True,
|
| 1112 |
-
)
|
| 1113 |
-
|
| 1114 |
-
@property
|
| 1115 |
-
def lc_secrets(self) -> Dict[str, str]:
|
| 1116 |
-
return {"google_api_key": "GOOGLE_API_KEY"}
|
| 1117 |
-
|
| 1118 |
-
@property
|
| 1119 |
-
def _llm_type(self) -> str:
|
| 1120 |
-
return "chat-google-generative-ai"
|
| 1121 |
-
|
| 1122 |
-
@property
|
| 1123 |
-
def _supports_code_execution(self) -> bool:
|
| 1124 |
-
return (
|
| 1125 |
-
"gemini-1.5-pro" in self.model
|
| 1126 |
-
or "gemini-1.5-flash" in self.model
|
| 1127 |
-
or "gemini-2" in self.model
|
| 1128 |
-
)
|
| 1129 |
-
|
| 1130 |
-
@classmethod
|
| 1131 |
-
def is_lc_serializable(self) -> bool:
|
| 1132 |
-
return True
|
| 1133 |
-
|
| 1134 |
-
@model_validator(mode="before")
|
| 1135 |
-
@classmethod
|
| 1136 |
-
def build_extra(cls, values: dict[str, Any]) -> Any:
|
| 1137 |
-
"""Build extra kwargs from additional params that were passed in."""
|
| 1138 |
-
all_required_field_names = get_pydantic_field_names(cls)
|
| 1139 |
-
values = _build_model_kwargs(values, all_required_field_names)
|
| 1140 |
-
return values
|
| 1141 |
-
|
| 1142 |
-
@model_validator(mode="after")
|
| 1143 |
-
def validate_environment(self) -> Self:
|
| 1144 |
-
"""Validates params and passes them to google-generativeai package."""
|
| 1145 |
-
if self.temperature is not None and not 0 <= self.temperature <= 2.0:
|
| 1146 |
-
raise ValueError("temperature must be in the range [0.0, 2.0]")
|
| 1147 |
-
|
| 1148 |
-
if self.top_p is not None and not 0 <= self.top_p <= 1:
|
| 1149 |
-
raise ValueError("top_p must be in the range [0.0, 1.0]")
|
| 1150 |
-
|
| 1151 |
-
if self.top_k is not None and self.top_k <= 0:
|
| 1152 |
-
raise ValueError("top_k must be positive")
|
| 1153 |
-
|
| 1154 |
-
if not any(
|
| 1155 |
-
self.model.startswith(prefix) for prefix in ("models/", "tunedModels/")
|
| 1156 |
-
):
|
| 1157 |
-
self.model = f"models/{self.model}"
|
| 1158 |
-
|
| 1159 |
-
additional_headers = self.additional_headers or {}
|
| 1160 |
-
self.default_metadata = tuple(additional_headers.items())
|
| 1161 |
-
client_info = get_client_info(f"ChatGoogleGenerativeAI:{self.model}")
|
| 1162 |
-
google_api_key = None
|
| 1163 |
-
if not self.credentials:
|
| 1164 |
-
if isinstance(self.google_api_key, SecretStr):
|
| 1165 |
-
google_api_key = self.google_api_key.get_secret_value()
|
| 1166 |
-
else:
|
| 1167 |
-
google_api_key = self.google_api_key
|
| 1168 |
-
transport: Optional[str] = self.transport
|
| 1169 |
-
self.client = genaix.build_generative_service(
|
| 1170 |
-
credentials=self.credentials,
|
| 1171 |
-
api_key=google_api_key,
|
| 1172 |
-
client_info=client_info,
|
| 1173 |
-
client_options=self.client_options,
|
| 1174 |
-
transport=transport,
|
| 1175 |
)
|
| 1176 |
-
|
| 1177 |
-
return self
|
| 1178 |
-
|
| 1179 |
-
@property
|
| 1180 |
-
def async_client(self) -> v1betaGenerativeServiceAsyncClient:
|
| 1181 |
-
google_api_key = None
|
| 1182 |
-
if not self.credentials:
|
| 1183 |
-
if isinstance(self.google_api_key, SecretStr):
|
| 1184 |
-
google_api_key = self.google_api_key.get_secret_value()
|
| 1185 |
-
else:
|
| 1186 |
-
google_api_key = self.google_api_key
|
| 1187 |
-
# NOTE: genaix.build_generative_async_service requires
|
| 1188 |
-
# a running event loop, which causes an error
|
| 1189 |
-
# when initialized inside a ThreadPoolExecutor.
|
| 1190 |
-
# this check ensures that async client is only initialized
|
| 1191 |
-
# within an asyncio event loop to avoid the error
|
| 1192 |
-
if not self.async_client_running and _is_event_loop_running():
|
| 1193 |
-
# async clients don't support "rest" transport
|
| 1194 |
-
# https://github.com/googleapis/gapic-generator-python/issues/1962
|
| 1195 |
-
transport = self.transport
|
| 1196 |
-
if transport == "rest":
|
| 1197 |
-
transport = "grpc_asyncio"
|
| 1198 |
-
self.async_client_running = genaix.build_generative_async_service(
|
| 1199 |
-
credentials=self.credentials,
|
| 1200 |
-
api_key=google_api_key,
|
| 1201 |
-
client_info=get_client_info(f"ChatGoogleGenerativeAI:{self.model}"),
|
| 1202 |
-
client_options=self.client_options,
|
| 1203 |
-
transport=transport,
|
| 1204 |
-
)
|
| 1205 |
-
return self.async_client_running
|
| 1206 |
-
|
| 1207 |
-
@property
|
| 1208 |
-
def _identifying_params(self) -> Dict[str, Any]:
|
| 1209 |
-
"""Get the identifying parameters."""
|
| 1210 |
return {
|
| 1211 |
-
"
|
| 1212 |
-
"
|
| 1213 |
-
"
|
| 1214 |
-
"
|
| 1215 |
-
"
|
| 1216 |
-
"
|
| 1217 |
-
"thinking_budget": self.thinking_budget,
|
| 1218 |
-
"include_thoughts": self.include_thoughts,
|
| 1219 |
}
|
| 1220 |
-
|
| 1221 |
-
|
| 1222 |
-
|
| 1223 |
-
|
| 1224 |
-
|
| 1225 |
-
|
| 1226 |
-
|
| 1227 |
-
|
| 1228 |
-
|
| 1229 |
-
|
| 1230 |
-
|
| 1231 |
-
|
| 1232 |
-
|
| 1233 |
-
code to solve problems.
|
| 1234 |
-
"""
|
| 1235 |
-
|
| 1236 |
-
"""Override invoke to add code_execution parameter."""
|
| 1237 |
-
|
| 1238 |
-
if code_execution is not None:
|
| 1239 |
-
if not self._supports_code_execution:
|
| 1240 |
-
raise ValueError(
|
| 1241 |
-
f"Code execution is only supported on Gemini 1.5 Pro, \
|
| 1242 |
-
Gemini 1.5 Flash, "
|
| 1243 |
-
f"Gemini 2.0 Flash, and Gemini 2.0 Pro models. \
|
| 1244 |
-
Current model: {self.model}"
|
| 1245 |
-
)
|
| 1246 |
-
if "tools" not in kwargs:
|
| 1247 |
-
code_execution_tool = GoogleTool(code_execution=CodeExecution())
|
| 1248 |
-
kwargs["tools"] = [code_execution_tool]
|
| 1249 |
-
|
| 1250 |
-
else:
|
| 1251 |
-
raise ValueError(
|
| 1252 |
-
"Tools are already defined." "code_execution tool can't be defined"
|
| 1253 |
-
)
|
| 1254 |
-
|
| 1255 |
-
return super().invoke(input, config, stop=stop, **kwargs)
|
| 1256 |
-
|
| 1257 |
-
def _get_ls_params(
|
| 1258 |
-
self, stop: Optional[List[str]] = None, **kwargs: Any
|
| 1259 |
-
) -> LangSmithParams:
|
| 1260 |
-
"""Get standard params for tracing."""
|
| 1261 |
-
params = self._get_invocation_params(stop=stop, **kwargs)
|
| 1262 |
-
models_prefix = "models/"
|
| 1263 |
-
ls_model_name = (
|
| 1264 |
-
self.model[len(models_prefix) :]
|
| 1265 |
-
if self.model and self.model.startswith(models_prefix)
|
| 1266 |
-
else self.model
|
| 1267 |
)
|
| 1268 |
-
|
| 1269 |
-
|
| 1270 |
-
|
| 1271 |
-
|
| 1272 |
-
|
| 1273 |
-
|
| 1274 |
-
|
| 1275 |
-
|
| 1276 |
-
if ls_stop := stop or params.get("stop", None):
|
| 1277 |
-
ls_params["ls_stop"] = ls_stop
|
| 1278 |
-
return ls_params
|
| 1279 |
-
|
| 1280 |
-
def _prepare_params(
|
| 1281 |
-
self,
|
| 1282 |
-
stop: Optional[List[str]],
|
| 1283 |
-
generation_config: Optional[Dict[str, Any]] = None,
|
| 1284 |
-
) -> GenerationConfig:
|
| 1285 |
-
gen_config = {
|
| 1286 |
-
k: v
|
| 1287 |
-
for k, v in {
|
| 1288 |
-
"candidate_count": self.n,
|
| 1289 |
-
"temperature": self.temperature,
|
| 1290 |
-
"stop_sequences": stop,
|
| 1291 |
-
"max_output_tokens": self.max_output_tokens,
|
| 1292 |
-
"top_k": self.top_k,
|
| 1293 |
-
"top_p": self.top_p,
|
| 1294 |
-
"response_modalities": self.response_modalities,
|
| 1295 |
-
"thinking_config": (
|
| 1296 |
-
(
|
| 1297 |
-
{"thinking_budget": self.thinking_budget}
|
| 1298 |
-
if self.thinking_budget is not None
|
| 1299 |
-
else {}
|
| 1300 |
-
)
|
| 1301 |
-
| (
|
| 1302 |
-
{"include_thoughts": self.include_thoughts}
|
| 1303 |
-
if self.include_thoughts is not None
|
| 1304 |
-
else {}
|
| 1305 |
-
)
|
| 1306 |
-
)
|
| 1307 |
-
if self.thinking_budget is not None or self.include_thoughts is not None
|
| 1308 |
-
else None,
|
| 1309 |
-
}.items()
|
| 1310 |
-
if v is not None
|
| 1311 |
}
|
| 1312 |
-
if generation_config:
|
| 1313 |
-
gen_config = {**gen_config, **generation_config}
|
| 1314 |
-
return GenerationConfig(**gen_config)
|
| 1315 |
-
|
| 1316 |
-
def _generate(
|
| 1317 |
-
self,
|
| 1318 |
-
messages: List[BaseMessage],
|
| 1319 |
-
stop: Optional[List[str]] = None,
|
| 1320 |
-
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
| 1321 |
-
*,
|
| 1322 |
-
tools: Optional[Sequence[Union[_ToolDict, GoogleTool]]] = None,
|
| 1323 |
-
functions: Optional[Sequence[_FunctionDeclarationType]] = None,
|
| 1324 |
-
safety_settings: Optional[SafetySettingDict] = None,
|
| 1325 |
-
tool_config: Optional[Union[Dict, _ToolConfigDict]] = None,
|
| 1326 |
-
generation_config: Optional[Dict[str, Any]] = None,
|
| 1327 |
-
cached_content: Optional[str] = None,
|
| 1328 |
-
tool_choice: Optional[Union[_ToolChoiceType, bool]] = None,
|
| 1329 |
-
**kwargs: Any,
|
| 1330 |
-
) -> ChatResult:
|
| 1331 |
-
request = self._prepare_request(
|
| 1332 |
-
messages,
|
| 1333 |
-
stop=stop,
|
| 1334 |
-
tools=tools,
|
| 1335 |
-
functions=functions,
|
| 1336 |
-
safety_settings=safety_settings,
|
| 1337 |
-
tool_config=tool_config,
|
| 1338 |
-
generation_config=generation_config,
|
| 1339 |
-
cached_content=cached_content or self.cached_content,
|
| 1340 |
-
tool_choice=tool_choice,
|
| 1341 |
-
)
|
| 1342 |
-
response: GenerateContentResponse = _chat_with_retry(
|
| 1343 |
-
request=request,
|
| 1344 |
-
**kwargs,
|
| 1345 |
-
generation_method=self.client.generate_content,
|
| 1346 |
-
metadata=self.default_metadata,
|
| 1347 |
-
)
|
| 1348 |
-
return _response_to_result(response)
|
| 1349 |
-
|
| 1350 |
-
async def _agenerate(
|
| 1351 |
-
self,
|
| 1352 |
-
messages: List[BaseMessage],
|
| 1353 |
-
stop: Optional[List[str]] = None,
|
| 1354 |
-
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
|
| 1355 |
-
*,
|
| 1356 |
-
tools: Optional[Sequence[Union[_ToolDict, GoogleTool]]] = None,
|
| 1357 |
-
functions: Optional[Sequence[_FunctionDeclarationType]] = None,
|
| 1358 |
-
safety_settings: Optional[SafetySettingDict] = None,
|
| 1359 |
-
tool_config: Optional[Union[Dict, _ToolConfigDict]] = None,
|
| 1360 |
-
generation_config: Optional[Dict[str, Any]] = None,
|
| 1361 |
-
cached_content: Optional[str] = None,
|
| 1362 |
-
tool_choice: Optional[Union[_ToolChoiceType, bool]] = None,
|
| 1363 |
-
**kwargs: Any,
|
| 1364 |
-
) -> ChatResult:
|
| 1365 |
-
if not self.async_client:
|
| 1366 |
-
updated_kwargs = {
|
| 1367 |
-
**kwargs,
|
| 1368 |
-
**{
|
| 1369 |
-
"tools": tools,
|
| 1370 |
-
"functions": functions,
|
| 1371 |
-
"safety_settings": safety_settings,
|
| 1372 |
-
"tool_config": tool_config,
|
| 1373 |
-
"generation_config": generation_config,
|
| 1374 |
-
},
|
| 1375 |
-
}
|
| 1376 |
-
return await super()._agenerate(
|
| 1377 |
-
messages, stop, run_manager, **updated_kwargs
|
| 1378 |
-
)
|
| 1379 |
-
|
| 1380 |
-
request = self._prepare_request(
|
| 1381 |
-
messages,
|
| 1382 |
-
stop=stop,
|
| 1383 |
-
tools=tools,
|
| 1384 |
-
functions=functions,
|
| 1385 |
-
safety_settings=safety_settings,
|
| 1386 |
-
tool_config=tool_config,
|
| 1387 |
-
generation_config=generation_config,
|
| 1388 |
-
cached_content=cached_content or self.cached_content,
|
| 1389 |
-
tool_choice=tool_choice,
|
| 1390 |
-
)
|
| 1391 |
-
response: GenerateContentResponse = await _achat_with_retry(
|
| 1392 |
-
request=request,
|
| 1393 |
-
**kwargs,
|
| 1394 |
-
generation_method=self.async_client.generate_content,
|
| 1395 |
-
metadata=self.default_metadata,
|
| 1396 |
-
)
|
| 1397 |
-
return _response_to_result(response)
|
| 1398 |
-
|
| 1399 |
-
def _stream(
|
| 1400 |
-
self,
|
| 1401 |
-
messages: List[BaseMessage],
|
| 1402 |
-
stop: Optional[List[str]] = None,
|
| 1403 |
-
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
| 1404 |
-
*,
|
| 1405 |
-
tools: Optional[Sequence[Union[_ToolDict, GoogleTool]]] = None,
|
| 1406 |
-
functions: Optional[Sequence[_FunctionDeclarationType]] = None,
|
| 1407 |
-
safety_settings: Optional[SafetySettingDict] = None,
|
| 1408 |
-
tool_config: Optional[Union[Dict, _ToolConfigDict]] = None,
|
| 1409 |
-
generation_config: Optional[Dict[str, Any]] = None,
|
| 1410 |
-
cached_content: Optional[str] = None,
|
| 1411 |
-
tool_choice: Optional[Union[_ToolChoiceType, bool]] = None,
|
| 1412 |
-
**kwargs: Any,
|
| 1413 |
-
) -> Iterator[ChatGenerationChunk]:
|
| 1414 |
-
request = self._prepare_request(
|
| 1415 |
-
messages,
|
| 1416 |
-
stop=stop,
|
| 1417 |
-
tools=tools,
|
| 1418 |
-
functions=functions,
|
| 1419 |
-
safety_settings=safety_settings,
|
| 1420 |
-
tool_config=tool_config,
|
| 1421 |
-
generation_config=generation_config,
|
| 1422 |
-
cached_content=cached_content or self.cached_content,
|
| 1423 |
-
tool_choice=tool_choice,
|
| 1424 |
-
)
|
| 1425 |
-
response: GenerateContentResponse = _chat_with_retry(
|
| 1426 |
-
request=request,
|
| 1427 |
-
generation_method=self.client.stream_generate_content,
|
| 1428 |
-
**kwargs,
|
| 1429 |
-
metadata=self.default_metadata,
|
| 1430 |
-
)
|
| 1431 |
-
|
| 1432 |
-
prev_usage_metadata: UsageMetadata | None = None
|
| 1433 |
-
for chunk in response:
|
| 1434 |
-
_chat_result = _response_to_result(
|
| 1435 |
-
chunk, stream=True, prev_usage=prev_usage_metadata
|
| 1436 |
-
)
|
| 1437 |
-
gen = cast(ChatGenerationChunk, _chat_result.generations[0])
|
| 1438 |
-
message = cast(AIMessageChunk, gen.message)
|
| 1439 |
|
| 1440 |
-
|
| 1441 |
-
|
| 1442 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1443 |
|
| 1444 |
-
|
| 1445 |
-
|
| 1446 |
-
|
| 1447 |
-
|
| 1448 |
-
|
| 1449 |
-
|
| 1450 |
-
|
| 1451 |
-
|
| 1452 |
-
|
| 1453 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1454 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1455 |
)
|
| 1456 |
-
|
| 1457 |
-
if run_manager:
|
| 1458 |
-
run_manager.on_llm_new_token(gen.text)
|
| 1459 |
-
yield gen
|
| 1460 |
-
|
| 1461 |
-
async def _astream(
|
| 1462 |
-
self,
|
| 1463 |
-
messages: List[BaseMessage],
|
| 1464 |
-
stop: Optional[List[str]] = None,
|
| 1465 |
-
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
|
| 1466 |
-
*,
|
| 1467 |
-
tools: Optional[Sequence[Union[_ToolDict, GoogleTool]]] = None,
|
| 1468 |
-
functions: Optional[Sequence[_FunctionDeclarationType]] = None,
|
| 1469 |
-
safety_settings: Optional[SafetySettingDict] = None,
|
| 1470 |
-
tool_config: Optional[Union[Dict, _ToolConfigDict]] = None,
|
| 1471 |
-
generation_config: Optional[Dict[str, Any]] = None,
|
| 1472 |
-
cached_content: Optional[str] = None,
|
| 1473 |
-
tool_choice: Optional[Union[_ToolChoiceType, bool]] = None,
|
| 1474 |
-
**kwargs: Any,
|
| 1475 |
-
) -> AsyncIterator[ChatGenerationChunk]:
|
| 1476 |
-
if not self.async_client:
|
| 1477 |
-
updated_kwargs = {
|
| 1478 |
-
**kwargs,
|
| 1479 |
-
**{
|
| 1480 |
-
"tools": tools,
|
| 1481 |
-
"functions": functions,
|
| 1482 |
-
"safety_settings": safety_settings,
|
| 1483 |
-
"tool_config": tool_config,
|
| 1484 |
-
"generation_config": generation_config,
|
| 1485 |
-
},
|
| 1486 |
-
}
|
| 1487 |
-
async for value in super()._astream(
|
| 1488 |
-
messages, stop, run_manager, **updated_kwargs
|
| 1489 |
-
):
|
| 1490 |
-
yield value
|
| 1491 |
else:
|
| 1492 |
-
|
| 1493 |
-
|
| 1494 |
-
|
| 1495 |
-
|
| 1496 |
-
|
| 1497 |
-
|
| 1498 |
-
|
| 1499 |
-
|
| 1500 |
-
|
| 1501 |
-
|
| 1502 |
-
|
| 1503 |
-
|
| 1504 |
-
|
| 1505 |
-
request=request,
|
| 1506 |
-
generation_method=self.async_client.stream_generate_content,
|
| 1507 |
-
**kwargs,
|
| 1508 |
-
metadata=self.default_metadata,
|
| 1509 |
-
):
|
| 1510 |
-
_chat_result = _response_to_result(
|
| 1511 |
-
chunk, stream=True, prev_usage=prev_usage_metadata
|
| 1512 |
)
|
| 1513 |
-
|
| 1514 |
-
|
| 1515 |
-
|
| 1516 |
-
|
| 1517 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1518 |
)
|
| 1519 |
-
|
| 1520 |
-
|
| 1521 |
-
|
| 1522 |
-
|
| 1523 |
-
|
| 1524 |
-
|
| 1525 |
-
|
| 1526 |
-
|
| 1527 |
-
|
| 1528 |
-
|
| 1529 |
-
|
| 1530 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1531 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1532 |
|
| 1533 |
-
|
| 1534 |
-
|
| 1535 |
-
|
| 1536 |
-
|
| 1537 |
-
|
| 1538 |
-
|
| 1539 |
-
|
| 1540 |
-
|
| 1541 |
-
|
| 1542 |
-
|
| 1543 |
-
|
| 1544 |
-
|
| 1545 |
-
|
| 1546 |
-
|
| 1547 |
-
generation_config: Optional[Dict[str, Any]] = None,
|
| 1548 |
-
cached_content: Optional[str] = None,
|
| 1549 |
-
) -> Tuple[GenerateContentRequest, Dict[str, Any]]:
|
| 1550 |
-
if tool_choice and tool_config:
|
| 1551 |
-
raise ValueError(
|
| 1552 |
-
"Must specify at most one of tool_choice and tool_config, received "
|
| 1553 |
-
f"both:\n\n{tool_choice=}\n\n{tool_config=}"
|
| 1554 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1555 |
|
| 1556 |
-
|
| 1557 |
-
|
| 1558 |
-
|
| 1559 |
-
formatted_tools = tools
|
| 1560 |
-
elif tools:
|
| 1561 |
-
formatted_tools = [convert_to_genai_function_declarations(tools)]
|
| 1562 |
-
elif functions:
|
| 1563 |
-
formatted_tools = [convert_to_genai_function_declarations(functions)]
|
| 1564 |
-
|
| 1565 |
-
filtered_messages = []
|
| 1566 |
-
for message in messages:
|
| 1567 |
-
if isinstance(message, HumanMessage) and not message.content:
|
| 1568 |
-
warnings.warn(
|
| 1569 |
-
"HumanMessage with empty content was removed to prevent API error"
|
| 1570 |
-
)
|
| 1571 |
-
else:
|
| 1572 |
-
filtered_messages.append(message)
|
| 1573 |
-
messages = filtered_messages
|
| 1574 |
-
|
| 1575 |
-
system_instruction, history = _parse_chat_history(
|
| 1576 |
-
messages,
|
| 1577 |
-
convert_system_message_to_human=self.convert_system_message_to_human,
|
| 1578 |
-
)
|
| 1579 |
-
if tool_choice:
|
| 1580 |
-
if not formatted_tools:
|
| 1581 |
-
msg = (
|
| 1582 |
-
f"Received {tool_choice=} but no {tools=}. 'tool_choice' can only "
|
| 1583 |
-
f"be specified if 'tools' is specified."
|
| 1584 |
-
)
|
| 1585 |
-
raise ValueError(msg)
|
| 1586 |
-
all_names: List[str] = []
|
| 1587 |
-
for t in formatted_tools:
|
| 1588 |
-
if hasattr(t, "function_declarations"):
|
| 1589 |
-
t_with_declarations = cast(Any, t)
|
| 1590 |
-
all_names.extend(
|
| 1591 |
-
f.name for f in t_with_declarations.function_declarations
|
| 1592 |
-
)
|
| 1593 |
-
elif isinstance(t, GoogleTool) and hasattr(t, "code_execution"):
|
| 1594 |
-
continue
|
| 1595 |
-
else:
|
| 1596 |
-
raise TypeError(
|
| 1597 |
-
f"Tool {t} doesn't have function_declarations attribute"
|
| 1598 |
-
)
|
| 1599 |
-
|
| 1600 |
-
tool_config = _tool_choice_to_tool_config(tool_choice, all_names)
|
| 1601 |
-
|
| 1602 |
-
formatted_tool_config = None
|
| 1603 |
-
if tool_config:
|
| 1604 |
-
formatted_tool_config = ToolConfig(
|
| 1605 |
-
function_calling_config=tool_config["function_calling_config"]
|
| 1606 |
-
)
|
| 1607 |
-
formatted_safety_settings = []
|
| 1608 |
-
if safety_settings:
|
| 1609 |
-
formatted_safety_settings = [
|
| 1610 |
-
SafetySetting(category=c, threshold=t)
|
| 1611 |
-
for c, t in safety_settings.items()
|
| 1612 |
-
]
|
| 1613 |
-
request = GenerateContentRequest(
|
| 1614 |
-
model=self.model,
|
| 1615 |
-
contents=history,
|
| 1616 |
-
tools=formatted_tools,
|
| 1617 |
-
tool_config=formatted_tool_config,
|
| 1618 |
-
safety_settings=formatted_safety_settings,
|
| 1619 |
-
generation_config=self._prepare_params(
|
| 1620 |
-
stop, generation_config=generation_config
|
| 1621 |
-
),
|
| 1622 |
-
cached_content=cached_content,
|
| 1623 |
-
)
|
| 1624 |
-
if system_instruction:
|
| 1625 |
-
request.system_instruction = system_instruction
|
| 1626 |
|
| 1627 |
-
|
|
|
|
| 1628 |
|
| 1629 |
-
|
| 1630 |
-
|
|
|
|
|
|
|
| 1631 |
|
| 1632 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1633 |
|
| 1634 |
-
|
| 1635 |
-
text: The string input to tokenize.
|
| 1636 |
|
| 1637 |
-
|
| 1638 |
-
The integer number of tokens in the text.
|
| 1639 |
-
"""
|
| 1640 |
-
result = self.client.count_tokens(
|
| 1641 |
-
model=self.model, contents=[Content(parts=[Part(text=text)])]
|
| 1642 |
-
)
|
| 1643 |
-
return result.total_tokens
|
| 1644 |
|
| 1645 |
-
|
| 1646 |
-
|
| 1647 |
-
|
| 1648 |
-
|
| 1649 |
-
|
| 1650 |
-
|
| 1651 |
-
|
| 1652 |
-
_ = kwargs.pop("method", None)
|
| 1653 |
-
_ = kwargs.pop("strict", None)
|
| 1654 |
-
if kwargs:
|
| 1655 |
-
raise ValueError(f"Received unsupported arguments {kwargs}")
|
| 1656 |
-
tool_name = _get_tool_name(schema) # type: ignore[arg-type]
|
| 1657 |
-
if isinstance(schema, type) and is_basemodel_subclass_safe(schema):
|
| 1658 |
-
parser: OutputParserLike = PydanticToolsParser(
|
| 1659 |
-
tools=[schema], first_tool_only=True
|
| 1660 |
-
)
|
| 1661 |
else:
|
| 1662 |
-
|
| 1663 |
-
|
|
|
|
|
|
|
| 1664 |
try:
|
| 1665 |
-
|
| 1666 |
-
[
|
| 1667 |
-
|
| 1668 |
-
|
| 1669 |
-
|
| 1670 |
-
|
| 1671 |
-
|
| 1672 |
-
|
| 1673 |
-
|
| 1674 |
-
|
| 1675 |
-
|
| 1676 |
-
|
| 1677 |
-
|
| 1678 |
-
|
| 1679 |
-
|
| 1680 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1681 |
)
|
| 1682 |
-
|
| 1683 |
-
|
| 1684 |
-
|
| 1685 |
-
|
| 1686 |
-
|
| 1687 |
-
|
| 1688 |
-
|
| 1689 |
-
|
| 1690 |
-
|
| 1691 |
-
|
| 1692 |
-
|
| 1693 |
-
|
| 1694 |
-
|
| 1695 |
-
|
| 1696 |
-
|
| 1697 |
-
|
| 1698 |
-
|
| 1699 |
-
|
| 1700 |
-
|
| 1701 |
-
|
| 1702 |
-
|
| 1703 |
-
|
| 1704 |
-
|
| 1705 |
-
**kwargs: Any additional parameters to pass to the
|
| 1706 |
-
:class:`~langchain.runnable.Runnable` constructor.
|
| 1707 |
-
"""
|
| 1708 |
-
if tool_choice and tool_config:
|
| 1709 |
-
raise ValueError(
|
| 1710 |
-
"Must specify at most one of tool_choice and tool_config, received "
|
| 1711 |
-
f"both:\n\n{tool_choice=}\n\n{tool_config=}"
|
| 1712 |
)
|
| 1713 |
-
|
| 1714 |
-
|
| 1715 |
-
|
| 1716 |
-
|
| 1717 |
-
|
| 1718 |
-
|
| 1719 |
-
|
| 1720 |
-
|
| 1721 |
-
|
| 1722 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1723 |
else:
|
| 1724 |
-
|
| 1725 |
-
|
| 1726 |
-
|
| 1727 |
-
|
| 1728 |
-
|
| 1729 |
-
|
| 1730 |
-
|
| 1731 |
-
|
| 1732 |
-
|
|
|
|
|
|
|
|
|
|
| 1733 |
)
|
| 1734 |
-
|
| 1735 |
-
|
| 1736 |
-
|
| 1737 |
-
|
| 1738 |
-
|
| 1739 |
-
|
| 1740 |
-
|
| 1741 |
-
|
| 1742 |
-
|
| 1743 |
-
if is_typeddict(tool):
|
| 1744 |
-
return convert_to_openai_tool(cast(Dict, tool))["function"]["name"]
|
| 1745 |
-
else:
|
| 1746 |
-
raise e
|
|
|
|
| 1 |
+
SYSTEM_PROMPT = """
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2 |
|
| 3 |
+
1. Air Quality Data (df):
|
| 4 |
+
- Columns: 'Timestamp', 'station', 'PM2.5', 'PM10', 'address', 'city', 'latitude', 'longitude', 'state'
|
| 5 |
+
- Example row: ['2023-01-01', 'StationA', 45.67, 78.9, '123 Main St', 'Mumbai', 19.07, 72.87, 'Maharashtra']
|
| 6 |
+
- Frequency: daily
|
| 7 |
+
- 'pollution' generally means 'PM2.5'.
|
| 8 |
+
- PM2.5 guidelines: India: 60, WHO: 15. PM10 guidelines: India: 100, WHO: 50.
|
| 9 |
|
| 10 |
+
2. NCAP Funding Data (ncap_data):
|
| 11 |
+
- Columns: 'city', 'state', 'funding_received', 'year', 'project', 'status'
|
| 12 |
+
- Example row: ['Mumbai', 'Maharashtra', 10000000, 2022, 'Clean Air Project', 'Ongoing']
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 13 |
|
| 14 |
+
3. State Population Data (states_data):
|
| 15 |
+
- Columns: 'state', 'population', 'year', 'urban_population', 'rural_population'
|
| 16 |
+
- Example row: ['Maharashtra', 123000000, 2021, 60000000, 63000000]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 17 |
|
| 18 |
+
You already have these dataframes loaded as df, ncap_data, and states_data. Do not read any files. Use these dataframes to answer questions about air quality, funding, or population. When aggregating, report standard deviation, standard error, and number of data points. Always report units. If a plot is required, follow the previous instructions for saving and reporting plots. If a question is about funding or population, use the relevant dataframe.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 19 |
"""
|
|
|
|
| 20 |
|
| 21 |
+
import os
|
| 22 |
+
import pandas as pd
|
| 23 |
+
from pandasai import Agent, SmartDataframe
|
| 24 |
+
from typing import Tuple
|
| 25 |
+
from PIL import Image
|
| 26 |
+
from pandasai.llm import HuggingFaceTextGen
|
| 27 |
+
from dotenv import load_dotenv
|
| 28 |
+
from langchain_groq import ChatGroq
|
| 29 |
+
from langchain_google_genai import ChatGoogleGenerativeAI
|
| 30 |
+
import matplotlib.pyplot as plt
|
| 31 |
+
import json
|
| 32 |
+
from datetime import datetime
|
| 33 |
+
from dotenv import load_dotenv
|
| 34 |
+
|
| 35 |
+
# FORCE reload environment variables
|
| 36 |
+
load_dotenv(override=True)
|
| 37 |
+
Groq_Token = os.getenv("GROQ_API_KEY")
|
| 38 |
+
hf_token = os.getenv("HF_TOKEN")
|
| 39 |
+
gemini_token = os.getenv("GEMINI_TOKEN")
|
| 40 |
+
import uuid
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 41 |
|
| 42 |
+
# FORCE reload environment variables
|
| 43 |
+
|
| 44 |
+
models = {
|
| 45 |
+
"gpt-oss-20b": "openai/gpt-oss-20b",
|
| 46 |
+
"gpt-oss-120b": "openai/gpt-oss-120b",
|
| 47 |
+
"llama3.1": "llama-3.1-8b-instant",
|
| 48 |
+
"llama3.3": "llama-3.3-70b-versatile",
|
| 49 |
+
"deepseek-R1": "deepseek-r1-distill-llama-70b",
|
| 50 |
+
"llama4 maverik":"meta-llama/llama-4-maverick-17b-128e-instruct",
|
| 51 |
+
"llama4 scout":"meta-llama/llama-4-scout-17b-16e-instruct",
|
| 52 |
+
"gemini-pro": "gemini-1.5-pro"
|
| 53 |
+
}
|
| 54 |
+
|
| 55 |
+
def log_interaction(user_query, model_name, response_content, generated_code, execution_time, error_message=None, is_image=False):
|
| 56 |
+
"""Log user interactions to Hugging Face dataset"""
|
| 57 |
try:
|
| 58 |
+
if not hf_token or hf_token.strip() == "":
|
| 59 |
+
print("Warning: HF_TOKEN not available, skipping logging")
|
| 60 |
+
return
|
| 61 |
+
|
| 62 |
+
# Create log entry
|
| 63 |
+
log_entry = {
|
| 64 |
+
"timestamp": datetime.now().isoformat(),
|
| 65 |
+
"session_id": str(uuid.uuid4()),
|
| 66 |
+
"user_query": user_query,
|
| 67 |
+
"model_name": model_name,
|
| 68 |
+
"response_content": str(response_content),
|
| 69 |
+
"generated_code": generated_code or "",
|
| 70 |
+
"execution_time_seconds": execution_time,
|
| 71 |
+
"error_message": error_message or "",
|
| 72 |
+
"is_image_output": is_image,
|
| 73 |
+
"success": error_message is None
|
| 74 |
+
}
|
| 75 |
+
|
| 76 |
+
# Create DataFrame
|
| 77 |
+
df = pd.DataFrame([log_entry])
|
| 78 |
+
|
| 79 |
+
# Create unique filename with timestamp
|
| 80 |
+
timestamp_str = datetime.now().strftime("%Y%m%d_%H%M%S")
|
| 81 |
+
random_id = str(uuid.uuid4())[:8]
|
| 82 |
+
filename = f"interaction_log_{timestamp_str}_{random_id}.parquet"
|
| 83 |
+
|
| 84 |
+
# Save locally first
|
| 85 |
+
local_path = f"/tmp/{filename}"
|
| 86 |
+
df.to_parquet(local_path, index=False)
|
| 87 |
+
# Clean up local file
|
| 88 |
+
if os.path.exists(local_path):
|
| 89 |
+
os.remove(local_path)
|
| 90 |
+
print(f"Successfully logged interaction locally: {filename}")
|
| 91 |
+
except Exception as e:
|
| 92 |
+
print(f"Error logging interaction: {e}")
|
| 93 |
+
|
| 94 |
+
def preprocess_and_load_df(path: str) -> pd.DataFrame:
|
| 95 |
+
"""Load and preprocess the dataframe"""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 96 |
try:
|
| 97 |
+
df = pd.read_csv(path)
|
| 98 |
+
df["Timestamp"] = pd.to_datetime(df["Timestamp"])
|
| 99 |
+
return df
|
| 100 |
+
except Exception as e:
|
| 101 |
+
raise Exception(f"Error loading dataframe: {e}")
|
| 102 |
+
|
| 103 |
+
def load_smart_df(df: pd.DataFrame, inference_server: str, name="mistral") -> SmartDataframe:
|
| 104 |
+
"""Load smart dataframe with error handling"""
|
| 105 |
+
try:
|
| 106 |
+
if name == "gemini-pro":
|
| 107 |
+
if not gemini_token or gemini_token.strip() == "":
|
| 108 |
+
raise ValueError("Gemini API token not available or empty")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 109 |
llm = ChatGoogleGenerativeAI(
|
| 110 |
+
model=models[name],
|
| 111 |
+
google_api_key=gemini_token,
|
| 112 |
+
temperature=0.1
|
| 113 |
)
|
| 114 |
+
else:
|
| 115 |
+
if not Groq_Token or Groq_Token.strip() == "":
|
| 116 |
+
raise ValueError("Groq API token not available or empty")
|
| 117 |
+
llm = ChatGroq(
|
| 118 |
+
model=models[name],
|
| 119 |
+
api_key=Groq_Token,
|
| 120 |
+
temperature=0.1
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 121 |
)
|
| 122 |
+
smart_df = SmartDataframe(df, config={"llm": llm, "max_retries": 5, "enable_cache": False})
|
| 123 |
+
return smart_df
|
| 124 |
+
except Exception as e:
|
| 125 |
+
raise Exception(f"Error loading smart dataframe: {e}")
|
| 126 |
+
try:
|
| 127 |
+
response = agent.chat(prompt)
|
| 128 |
+
execution_time = (datetime.now() - start_time).total_seconds()
|
| 129 |
+
|
| 130 |
+
gen_code = getattr(agent, 'last_code_generated', '')
|
| 131 |
+
ex_code = getattr(agent, 'last_code_executed', '')
|
| 132 |
+
last_prompt = getattr(agent, 'last_prompt', prompt)
|
| 133 |
+
|
| 134 |
+
# Log the interaction
|
| 135 |
+
log_interaction(
|
| 136 |
+
user_query=prompt,
|
| 137 |
+
model_name="pandas_ai_agent",
|
| 138 |
+
response_content=response,
|
| 139 |
+
generated_code=gen_code,
|
| 140 |
+
execution_time=execution_time,
|
| 141 |
+
error_message=None,
|
| 142 |
+
is_image=isinstance(response, str) and any(response.endswith(ext) for ext in ['.png', '.jpg', '.jpeg'])
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 143 |
)
|
| 144 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 145 |
return {
|
| 146 |
+
"role": "assistant",
|
| 147 |
+
"content": response,
|
| 148 |
+
"gen_code": gen_code,
|
| 149 |
+
"ex_code": ex_code,
|
| 150 |
+
"last_prompt": last_prompt,
|
| 151 |
+
"error": None
|
|
|
|
|
|
|
| 152 |
}
|
| 153 |
+
except Exception as e:
|
| 154 |
+
execution_time = (datetime.now() - start_time).total_seconds()
|
| 155 |
+
error_msg = str(e)
|
| 156 |
+
|
| 157 |
+
# Log the failed interaction
|
| 158 |
+
log_interaction(
|
| 159 |
+
user_query=prompt,
|
| 160 |
+
model_name="pandas_ai_agent",
|
| 161 |
+
response_content=f"Error: {error_msg}",
|
| 162 |
+
generated_code="",
|
| 163 |
+
execution_time=execution_time,
|
| 164 |
+
error_message=error_msg,
|
| 165 |
+
is_image=False
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 166 |
)
|
| 167 |
+
|
| 168 |
+
return {
|
| 169 |
+
"role": "assistant",
|
| 170 |
+
"content": f"Error: {error_msg}",
|
| 171 |
+
"gen_code": "",
|
| 172 |
+
"ex_code": "",
|
| 173 |
+
"last_prompt": prompt,
|
| 174 |
+
"error": error_msg
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 175 |
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 176 |
|
| 177 |
+
def decorate_with_code(response: dict) -> str:
|
| 178 |
+
"""Decorate response with code details"""
|
| 179 |
+
gen_code = response.get("gen_code", "No code generated")
|
| 180 |
+
last_prompt = response.get("last_prompt", "No prompt")
|
| 181 |
+
|
| 182 |
+
return f"""<details>
|
| 183 |
+
<summary>Generated Code</summary>
|
| 184 |
+
|
| 185 |
+
```python
|
| 186 |
+
{gen_code}
|
| 187 |
+
```
|
| 188 |
+
</details>
|
| 189 |
+
|
| 190 |
+
<details>
|
| 191 |
+
<summary>Prompt</summary>
|
| 192 |
+
|
| 193 |
+
{last_prompt}
|
| 194 |
+
"""
|
| 195 |
|
| 196 |
+
def show_response(st, response):
|
| 197 |
+
"""Display response with error handling"""
|
| 198 |
+
try:
|
| 199 |
+
with st.chat_message(response["role"]):
|
| 200 |
+
content = response.get("content", "No content")
|
| 201 |
+
|
| 202 |
+
try:
|
| 203 |
+
# Try to open as image
|
| 204 |
+
image = Image.open(content)
|
| 205 |
+
if response.get("gen_code"):
|
| 206 |
+
st.markdown(decorate_with_code(response), unsafe_allow_html=True)
|
| 207 |
+
st.image(image)
|
| 208 |
+
return {"is_image": True}
|
| 209 |
+
except:
|
| 210 |
+
# Not an image, display as text
|
| 211 |
+
if response.get("gen_code"):
|
| 212 |
+
display_content = decorate_with_code(response) + f"""</details>
|
| 213 |
+
|
| 214 |
+
{content}"""
|
| 215 |
+
else:
|
| 216 |
+
display_content = content
|
| 217 |
+
st.markdown(display_content, unsafe_allow_html=True)
|
| 218 |
+
return {"is_image": False}
|
| 219 |
+
except Exception as e:
|
| 220 |
+
st.error(f"Error displaying response: {e}")
|
| 221 |
+
return {"is_image": False}
|
| 222 |
+
|
| 223 |
+
def ask_question(model_name, question):
|
| 224 |
+
"""Ask question with comprehensive error handling and logging"""
|
| 225 |
+
start_time = datetime.now()
|
| 226 |
+
try:
|
| 227 |
+
# Reload environment variables to get fresh values
|
| 228 |
+
load_dotenv(override=True)
|
| 229 |
+
fresh_groq_token = os.getenv("GROQ_API_KEY")
|
| 230 |
+
fresh_gemini_token = os.getenv("GEMINI_TOKEN")
|
| 231 |
+
|
| 232 |
+
print(f"ask_question - Fresh Groq Token: {'Present' if fresh_groq_token else 'Missing'}")
|
| 233 |
+
|
| 234 |
+
# Check API availability with fresh tokens
|
| 235 |
+
if model_name == "gemini-pro":
|
| 236 |
+
if not fresh_gemini_token or fresh_gemini_token.strip() == "":
|
| 237 |
+
execution_time = (datetime.now() - start_time).total_seconds()
|
| 238 |
+
error_msg = "Missing or empty API token"
|
| 239 |
+
|
| 240 |
+
# Log the failed interaction
|
| 241 |
+
log_interaction(
|
| 242 |
+
user_query=question,
|
| 243 |
+
model_name=model_name,
|
| 244 |
+
response_content="❌ Gemini API token not available or empty",
|
| 245 |
+
generated_code="",
|
| 246 |
+
execution_time=execution_time,
|
| 247 |
+
error_message=error_msg,
|
| 248 |
+
is_image=False
|
| 249 |
)
|
| 250 |
+
|
| 251 |
+
return {
|
| 252 |
+
"role": "assistant",
|
| 253 |
+
"content": "❌ Gemini API token not available or empty. Please set GEMINI_TOKEN in your environment variables.",
|
| 254 |
+
"gen_code": "",
|
| 255 |
+
"ex_code": "",
|
| 256 |
+
"last_prompt": question,
|
| 257 |
+
"error": error_msg
|
| 258 |
+
}
|
| 259 |
+
llm = ChatGoogleGenerativeAI(
|
| 260 |
+
model=models[model_name],
|
| 261 |
+
google_api_key=fresh_gemini_token,
|
| 262 |
+
temperature=0
|
| 263 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 264 |
else:
|
| 265 |
+
if not fresh_groq_token or fresh_groq_token.strip() == "":
|
| 266 |
+
execution_time = (datetime.now() - start_time).total_seconds()
|
| 267 |
+
error_msg = "Missing or empty API token"
|
| 268 |
+
|
| 269 |
+
# Log the failed interaction
|
| 270 |
+
log_interaction(
|
| 271 |
+
user_query=question,
|
| 272 |
+
model_name=model_name,
|
| 273 |
+
response_content="❌ Groq API token not available or empty",
|
| 274 |
+
generated_code="",
|
| 275 |
+
execution_time=execution_time,
|
| 276 |
+
error_message=error_msg,
|
| 277 |
+
is_image=False
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 278 |
)
|
| 279 |
+
|
| 280 |
+
return {
|
| 281 |
+
"role": "assistant",
|
| 282 |
+
"content": "❌ Groq API token not available or empty. Please set GROQ_API_KEY in your environment variables and restart the application.",
|
| 283 |
+
"gen_code": "",
|
| 284 |
+
"ex_code": "",
|
| 285 |
+
"last_prompt": question,
|
| 286 |
+
"error": error_msg
|
| 287 |
+
}
|
| 288 |
+
|
| 289 |
+
# Test the API key by trying to create the client
|
| 290 |
+
try:
|
| 291 |
+
llm = ChatGroq(
|
| 292 |
+
model=models[model_name],
|
| 293 |
+
api_key=fresh_groq_token,
|
| 294 |
+
temperature=0.1
|
| 295 |
)
|
| 296 |
+
# Test with a simple call to verify the API key works
|
| 297 |
+
test_response = llm.invoke("Test")
|
| 298 |
+
print("API key test successful")
|
| 299 |
+
except Exception as api_error:
|
| 300 |
+
execution_time = (datetime.now() - start_time).total_seconds()
|
| 301 |
+
error_msg = str(api_error)
|
| 302 |
+
|
| 303 |
+
if "organization_restricted" in error_msg.lower() or "unauthorized" in error_msg.lower():
|
| 304 |
+
response_content = "❌ API Key Error: Your Groq API key appears to be invalid, expired, or restricted. Please check your API key in the .env file."
|
| 305 |
+
log_error_msg = f"API key validation failed: {error_msg}"
|
| 306 |
+
else:
|
| 307 |
+
response_content = f"❌ API Connection Error: {error_msg}"
|
| 308 |
+
log_error_msg = error_msg
|
| 309 |
+
|
| 310 |
+
# Log the failed interaction
|
| 311 |
+
log_interaction(
|
| 312 |
+
user_query=question,
|
| 313 |
+
model_name=model_name,
|
| 314 |
+
response_content=response_content,
|
| 315 |
+
generated_code="",
|
| 316 |
+
execution_time=execution_time,
|
| 317 |
+
error_message=log_error_msg,
|
| 318 |
+
is_image=False
|
| 319 |
)
|
| 320 |
+
|
| 321 |
+
return {
|
| 322 |
+
"role": "assistant",
|
| 323 |
+
"content": response_content,
|
| 324 |
+
"gen_code": "",
|
| 325 |
+
"ex_code": "",
|
| 326 |
+
"last_prompt": question,
|
| 327 |
+
"error": log_error_msg
|
| 328 |
+
}
|
| 329 |
|
| 330 |
+
# Check if data file exists
|
| 331 |
+
if not os.path.exists("Data.csv"):
|
| 332 |
+
execution_time = (datetime.now() - start_time).total_seconds()
|
| 333 |
+
error_msg = "Data file not found"
|
| 334 |
+
|
| 335 |
+
# Log the failed interaction
|
| 336 |
+
log_interaction(
|
| 337 |
+
user_query=question,
|
| 338 |
+
model_name=model_name,
|
| 339 |
+
response_content="❌ Data.csv file not found",
|
| 340 |
+
generated_code="",
|
| 341 |
+
execution_time=execution_time,
|
| 342 |
+
error_message=error_msg,
|
| 343 |
+
is_image=False
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 344 |
)
|
| 345 |
+
|
| 346 |
+
return {
|
| 347 |
+
"role": "assistant",
|
| 348 |
+
"content": "❌ Data.csv file not found. Please ensure the data file is in the correct location.",
|
| 349 |
+
"gen_code": "",
|
| 350 |
+
"ex_code": "",
|
| 351 |
+
"last_prompt": question,
|
| 352 |
+
"error": error_msg
|
| 353 |
+
}
|
| 354 |
|
| 355 |
+
df_check = pd.read_csv("Data.csv")
|
| 356 |
+
df_check["Timestamp"] = pd.to_datetime(df_check["Timestamp"])
|
| 357 |
+
df_check = df_check.head(5)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 358 |
|
| 359 |
+
new_line = "\n"
|
| 360 |
+
parameters = {"font.size": 12, "figure.dpi": 600}
|
| 361 |
|
| 362 |
+
template = f"""```python
|
| 363 |
+
import pandas as pd
|
| 364 |
+
import matplotlib.pyplot as plt
|
| 365 |
+
import uuid
|
| 366 |
|
| 367 |
+
plt.rcParams.update({parameters})
|
| 368 |
+
|
| 369 |
+
df = pd.read_csv("Data.csv")
|
| 370 |
+
df["Timestamp"] = pd.to_datetime(df["Timestamp"])
|
| 371 |
+
|
| 372 |
+
# Available columns and data types:
|
| 373 |
+
{new_line.join(map(lambda x: '# '+x, str(df_check.dtypes).split(new_line)))}
|
| 374 |
+
|
| 375 |
+
# Question: {question.strip()}
|
| 376 |
+
# Generate code to answer the question and save result in 'answer' variable
|
| 377 |
+
# If creating a plot, save it with a unique filename and store the filename in 'answer'
|
| 378 |
+
# If returning text/numbers, store the result directly in 'answer'
|
| 379 |
+
```"""
|
| 380 |
+
|
| 381 |
+
system_prompt = """You are a helpful assistant that generates Python code for data analysis.
|
| 382 |
+
|
| 383 |
+
Rules:
|
| 384 |
+
1. Always save your final result in a variable called 'answer'
|
| 385 |
+
2. If creating a plot, save it with plt.savefig() and store the filename in 'answer'
|
| 386 |
+
3. If returning text/numbers, store the result directly in 'answer'
|
| 387 |
+
4. Use descriptive variable names and add comments
|
| 388 |
+
5. Handle potential errors gracefully
|
| 389 |
+
6. For plots, use unique filenames to avoid conflicts
|
| 390 |
+
"""
|
| 391 |
|
| 392 |
+
query = f"""{system_prompt}
|
|
|
|
| 393 |
|
| 394 |
+
Complete the following code to answer the user's question:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 395 |
|
| 396 |
+
{template}
|
| 397 |
+
"""
|
| 398 |
+
|
| 399 |
+
# Make API call
|
| 400 |
+
if model_name == "gemini-pro":
|
| 401 |
+
response = llm.invoke(query)
|
| 402 |
+
answer = response.content
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 403 |
else:
|
| 404 |
+
response = llm.invoke(query)
|
| 405 |
+
answer = response.content
|
| 406 |
+
|
| 407 |
+
# Extract and execute code
|
| 408 |
try:
|
| 409 |
+
if "```python" in answer:
|
| 410 |
+
code_part = answer.split("```python")[1].split("```")[0]
|
| 411 |
+
else:
|
| 412 |
+
code_part = answer
|
| 413 |
+
|
| 414 |
+
full_code = f"""
|
| 415 |
+
{template.split("```python")[1].split("```")[0]}
|
| 416 |
+
{code_part}
|
| 417 |
+
"""
|
| 418 |
+
|
| 419 |
+
# Execute code in a controlled environment
|
| 420 |
+
local_vars = {}
|
| 421 |
+
global_vars = {
|
| 422 |
+
'pd': pd,
|
| 423 |
+
'plt': plt,
|
| 424 |
+
'os': os,
|
| 425 |
+
'uuid': __import__('uuid')
|
| 426 |
+
}
|
| 427 |
+
|
| 428 |
+
exec(full_code, global_vars, local_vars)
|
| 429 |
+
|
| 430 |
+
# Get the answer
|
| 431 |
+
if 'answer' in local_vars:
|
| 432 |
+
answer_result = local_vars['answer']
|
| 433 |
+
else:
|
| 434 |
+
answer_result = "No answer variable found in generated code"
|
| 435 |
+
|
| 436 |
+
execution_time = (datetime.now() - start_time).total_seconds()
|
| 437 |
+
|
| 438 |
+
# Determine if output is an image
|
| 439 |
+
is_image = isinstance(answer_result, str) and any(answer_result.endswith(ext) for ext in ['.png', '.jpg', '.jpeg'])
|
| 440 |
+
|
| 441 |
+
# Log successful interaction
|
| 442 |
+
log_interaction(
|
| 443 |
+
user_query=question,
|
| 444 |
+
model_name=model_name,
|
| 445 |
+
response_content=str(answer_result),
|
| 446 |
+
generated_code=full_code,
|
| 447 |
+
execution_time=execution_time,
|
| 448 |
+
error_message=None,
|
| 449 |
+
is_image=is_image
|
| 450 |
)
|
| 451 |
+
|
| 452 |
+
return {
|
| 453 |
+
"role": "assistant",
|
| 454 |
+
"content": answer_result,
|
| 455 |
+
"gen_code": full_code,
|
| 456 |
+
"ex_code": full_code,
|
| 457 |
+
"last_prompt": question,
|
| 458 |
+
"error": None
|
| 459 |
+
}
|
| 460 |
+
|
| 461 |
+
except Exception as code_error:
|
| 462 |
+
execution_time = (datetime.now() - start_time).total_seconds()
|
| 463 |
+
error_msg = str(code_error)
|
| 464 |
+
|
| 465 |
+
# Log the failed code execution
|
| 466 |
+
log_interaction(
|
| 467 |
+
user_query=question,
|
| 468 |
+
model_name=model_name,
|
| 469 |
+
response_content=f"❌ Error executing generated code: {error_msg}",
|
| 470 |
+
generated_code=full_code if 'full_code' in locals() else "",
|
| 471 |
+
execution_time=execution_time,
|
| 472 |
+
error_message=error_msg,
|
| 473 |
+
is_image=False
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 474 |
)
|
| 475 |
+
|
| 476 |
+
return {
|
| 477 |
+
"role": "assistant",
|
| 478 |
+
"content": f"❌ Error executing generated code: {error_msg}",
|
| 479 |
+
"gen_code": full_code if 'full_code' in locals() else "",
|
| 480 |
+
"ex_code": full_code if 'full_code' in locals() else "",
|
| 481 |
+
"last_prompt": question,
|
| 482 |
+
"error": error_msg
|
| 483 |
+
}
|
| 484 |
+
|
| 485 |
+
except Exception as e:
|
| 486 |
+
execution_time = (datetime.now() - start_time).total_seconds()
|
| 487 |
+
error_msg = str(e)
|
| 488 |
+
|
| 489 |
+
# Handle specific API errors
|
| 490 |
+
if "organization_restricted" in error_msg:
|
| 491 |
+
response_content = "❌ API Organization Restricted: Your API key access has been restricted. Please check your Groq API key or try generating a new one."
|
| 492 |
+
log_error_msg = "API access restricted"
|
| 493 |
+
elif "rate_limit" in error_msg.lower():
|
| 494 |
+
response_content = "❌ Rate limit exceeded. Please wait a moment and try again."
|
| 495 |
+
log_error_msg = "Rate limit exceeded"
|
| 496 |
else:
|
| 497 |
+
response_content = f"❌ Error: {error_msg}"
|
| 498 |
+
log_error_msg = error_msg
|
| 499 |
+
|
| 500 |
+
# Log the failed interaction
|
| 501 |
+
log_interaction(
|
| 502 |
+
user_query=question,
|
| 503 |
+
model_name=model_name,
|
| 504 |
+
response_content=response_content,
|
| 505 |
+
generated_code="",
|
| 506 |
+
execution_time=execution_time,
|
| 507 |
+
error_message=log_error_msg,
|
| 508 |
+
is_image=False
|
| 509 |
)
|
| 510 |
+
|
| 511 |
+
return {
|
| 512 |
+
"role": "assistant",
|
| 513 |
+
"content": response_content,
|
| 514 |
+
"gen_code": "",
|
| 515 |
+
"ex_code": "",
|
| 516 |
+
"last_prompt": question,
|
| 517 |
+
"error": log_error_msg
|
| 518 |
+
}
|
|
|
|
|
|
|
|
|
|
|
|