AbhayVG committed on
Commit
a1e5ea7
·
verified ·
1 Parent(s): bf3ae6d
Files changed (1) hide show
  1. src.py +478 -1706
src.py CHANGED
@@ -1,1746 +1,518 @@
1
- from __future__ import annotations
2
-
3
- import asyncio
4
- import base64
5
- import json
6
- import logging
7
- import mimetypes
8
- import uuid
9
- import warnings
10
- from difflib import get_close_matches
11
- from operator import itemgetter
12
- from typing import (
13
- Any,
14
- AsyncIterator,
15
- Callable,
16
- Dict,
17
- Iterator,
18
- List,
19
- Mapping,
20
- Optional,
21
- Sequence,
22
- Tuple,
23
- Type,
24
- Union,
25
- cast,
26
- )
27
-
28
- import filetype # type: ignore[import]
29
- import google.api_core
30
-
31
- # TODO: remove ignore once the google package is published with types
32
- import proto # type: ignore[import]
33
- from google.ai.generativelanguage_v1beta import (
34
- GenerativeServiceAsyncClient as v1betaGenerativeServiceAsyncClient,
35
- )
36
- from google.ai.generativelanguage_v1beta.types import (
37
- Blob,
38
- Candidate,
39
- CodeExecution,
40
- Content,
41
- FileData,
42
- FunctionCall,
43
- FunctionDeclaration,
44
- FunctionResponse,
45
- GenerateContentRequest,
46
- GenerateContentResponse,
47
- GenerationConfig,
48
- Part,
49
- SafetySetting,
50
- ToolConfig,
51
- VideoMetadata,
52
- )
53
- from google.ai.generativelanguage_v1beta.types import Tool as GoogleTool
54
- from langchain_core.callbacks.manager import (
55
- AsyncCallbackManagerForLLMRun,
56
- CallbackManagerForLLMRun,
57
- )
58
- from langchain_core.language_models import LanguageModelInput
59
- from langchain_core.language_models.chat_models import BaseChatModel, LangSmithParams
60
- from langchain_core.messages import (
61
- AIMessage,
62
- AIMessageChunk,
63
- BaseMessage,
64
- FunctionMessage,
65
- HumanMessage,
66
- SystemMessage,
67
- ToolMessage,
68
- is_data_content_block,
69
- )
70
- from langchain_core.messages.ai import UsageMetadata
71
- from langchain_core.messages.tool import invalid_tool_call, tool_call, tool_call_chunk
72
- from langchain_core.output_parsers.base import OutputParserLike
73
- from langchain_core.output_parsers.openai_tools import (
74
- JsonOutputKeyToolsParser,
75
- PydanticToolsParser,
76
- parse_tool_calls,
77
- )
78
- from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
79
- from langchain_core.runnables import Runnable, RunnableConfig, RunnablePassthrough
80
- from langchain_core.tools import BaseTool
81
- from langchain_core.utils import get_pydantic_field_names
82
- from langchain_core.utils.function_calling import convert_to_openai_tool
83
- from langchain_core.utils.utils import _build_model_kwargs
84
- from pydantic import (
85
- BaseModel,
86
- ConfigDict,
87
- Field,
88
- SecretStr,
89
- model_validator,
90
- )
91
- from tenacity import (
92
- before_sleep_log,
93
- retry,
94
- retry_if_exception_type,
95
- stop_after_attempt,
96
- wait_exponential,
97
- )
98
- from typing_extensions import Self, is_typeddict
99
-
100
- from langchain_google_genai._common import (
101
- GoogleGenerativeAIError,
102
- SafetySettingDict,
103
- _BaseGoogleGenerativeAI,
104
- get_client_info,
105
- )
106
- from langchain_google_genai._function_utils import (
107
- _tool_choice_to_tool_config,
108
- _ToolChoiceType,
109
- _ToolConfigDict,
110
- _ToolDict,
111
- convert_to_genai_function_declarations,
112
- is_basemodel_subclass_safe,
113
- tool_to_dict,
114
- )
115
- from langchain_google_genai._image_utils import (
116
- ImageBytesLoader,
117
- image_bytes_to_b64_string,
118
- )
119
-
120
- from . import _genai_extension as genaix
121
-
122
- logger = logging.getLogger(__name__)
123
-
124
-
125
- _FunctionDeclarationType = Union[
126
- FunctionDeclaration,
127
- dict[str, Any],
128
- Callable[..., Any],
129
- ]
130
-
131
-
132
class ChatGoogleGenerativeAIError(GoogleGenerativeAIError):
    """Error raised for problems specific to the `Google GenAI` API.

    Raised by ``ChatGoogleGenerativeAI`` when the Google genai API is used
    incorrectly — for example with unsupported message types or roles.
    """
140
-
141
-
142
def _create_retry_decorator() -> Callable[[Any], Any]:
    """Build the tenacity retry decorator shared by the (a)sync chat calls.

    Retries on ``ResourceExhausted``, ``ServiceUnavailable`` and generic
    ``GoogleAPIError`` with exponential backoff (multiplier 2, wait bounded
    to [1, 60] seconds, at most 2 attempts), re-raising the final error and
    logging a warning before each sleep.

    Returns:
        Callable[[Any], Any]: A configured tenacity ``retry`` decorator.
    """
    # retry_if_exception_type accepts a tuple of types, equivalent to OR-ing
    # one predicate per exception class.
    retryable_exceptions = (
        google.api_core.exceptions.ResourceExhausted,
        google.api_core.exceptions.ServiceUnavailable,
        google.api_core.exceptions.GoogleAPIError,
    )
    return retry(
        reraise=True,
        stop=stop_after_attempt(2),
        wait=wait_exponential(multiplier=2, min=1, max=60),
        retry=retry_if_exception_type(retryable_exceptions),
        before_sleep=before_sleep_log(logger, logging.WARNING),
    )
170
-
171
-
172
def _chat_with_retry(generation_method: Callable, **kwargs: Any) -> Any:
    """Execute a chat generation method with retry logic using tenacity.

    Wraps a provided chat generation function with the retry policy from
    ``_create_retry_decorator`` to ride out intermittent issues such as
    quota exhaustion or temporary service unavailability.

    Args:
        generation_method (Callable): The chat generation method to be executed.
        **kwargs (Any): Additional keyword arguments to pass to the generation method.

    Returns:
        Any: The result from the chat generation method.

    Raises:
        ValueError: If the caller's location is not supported by the API.
        ChatGoogleGenerativeAIError: On an ``InvalidArgument`` API error.
    """
    retry_decorator = _create_retry_decorator()

    @retry_decorator
    def _chat_with_retry(**kwargs: Any) -> Any:
        try:
            return generation_method(**kwargs)
        # Do not retry for these errors.
        except google.api_core.exceptions.FailedPrecondition as exc:
            if "location is not supported" in exc.message:
                error_msg = (
                    "Your location is not supported by google-generativeai "
                    "at the moment. Try to use ChatVertexAI LLM from "
                    "langchain_google_vertexai."
                )
                raise ValueError(error_msg) from exc
            # Bug fix: previously any other FailedPrecondition fell off the
            # end of this handler and was silently swallowed (returning None).
            raise
        except google.api_core.exceptions.InvalidArgument as e:
            raise ChatGoogleGenerativeAIError(
                f"Invalid argument provided to Gemini: {e}"
            ) from e

    return _chat_with_retry(**kwargs)
211
-
212
-
213
async def _achat_with_retry(generation_method: Callable, **kwargs: Any) -> Any:
    """Execute an async chat generation method with retry logic using tenacity.

    Async counterpart of ``_chat_with_retry``: applies the same retry policy
    from ``_create_retry_decorator`` to an awaitable generation function.

    Args:
        generation_method (Callable): The chat generation method to be executed.
        **kwargs (Any): Additional keyword arguments to pass to the generation method.

    Returns:
        Any: The result from the chat generation method.

    Raises:
        ChatGoogleGenerativeAIError: On an ``InvalidArgument`` API error.
    """
    retry_decorator = _create_retry_decorator()

    @retry_decorator
    async def _achat_with_retry(**kwargs: Any) -> Any:
        try:
            return await generation_method(**kwargs)
        # Use the module-level google.api_core import, matching the sync
        # twin, instead of a redundant function-local re-import.
        except google.api_core.exceptions.InvalidArgument as e:
            # Do not retry for these errors.
            raise ChatGoogleGenerativeAIError(
                f"Invalid argument provided to Gemini: {e}"
            ) from e

    return await _achat_with_retry(**kwargs)
244
-
245
-
246
- def _is_lc_content_block(part: dict) -> bool:
247
- return "type" in part
248
-
249
-
250
- def _is_openai_image_block(block: dict) -> bool:
251
- """Check if the block contains image data in OpenAI Chat Completions format."""
252
- if block.get("type") == "image_url":
253
- if (
254
- (set(block.keys()) <= {"type", "image_url", "detail"})
255
- and (image_url := block.get("image_url"))
256
- and isinstance(image_url, dict)
257
- ):
258
- url = image_url.get("url")
259
- if isinstance(url, str):
260
- return True
261
- else:
262
- return False
263
-
264
- return False
265
 
 
 
 
 
 
 
266
 
267
def _convert_to_parts(
    raw_content: Union[str, Sequence[Union[str, dict]]],
) -> List[Part]:
    """Convert LangChain message content into a list of Gemini ``Part`` objects.

    Args:
        raw_content: A plain string, or a sequence of strings and/or LangChain
            content-block dicts (``text``, standard data blocks, ``image_url``,
            ``media``).

    Returns:
        The equivalent list of ``google.ai.generativelanguage`` parts.

    Raises:
        ValueError: On malformed blocks (unknown ``source_type``, missing
            ``url``/``mime_type``/payload, unrecognized block type).
        ChatGoogleGenerativeAIError: When a part is neither a string nor a
            mapping.
    """
    parts = []
    # Normalize: a bare string becomes a single text part.
    content = [raw_content] if isinstance(raw_content, str) else raw_content
    image_loader = ImageBytesLoader()
    for part in content:
        if isinstance(part, str):
            parts.append(Part(text=part))
        elif isinstance(part, Mapping):
            if _is_lc_content_block(part):
                if part["type"] == "text":
                    parts.append(Part(text=part["text"]))
                elif is_data_content_block(part):
                    # Standard LangChain data block: resolve the raw bytes first.
                    if part["source_type"] == "url":
                        bytes_ = image_loader._bytes_from_url(part["url"])
                    elif part["source_type"] == "base64":
                        bytes_ = base64.b64decode(part["data"])
                    else:
                        raise ValueError("source_type must be url or base64.")
                    inline_data: dict = {"data": bytes_}
                    if "mime_type" in part:
                        inline_data["mime_type"] = part["mime_type"]
                    else:
                        # No explicit mime type: guess from the URL/data string,
                        # then fall back to sniffing the bytes with `filetype`.
                        source = cast(str, part.get("url") or part.get("data"))
                        mime_type, _ = mimetypes.guess_type(source)
                        if not mime_type:
                            kind = filetype.guess(bytes_)
                            if kind:
                                mime_type = kind.mime
                        if mime_type:
                            inline_data["mime_type"] = mime_type
                    parts.append(Part(inline_data=inline_data))
                elif part["type"] == "image_url":
                    img_url = part["image_url"]
                    # OpenAI-style blocks nest the URL under an "url" key.
                    if isinstance(img_url, dict):
                        if "url" not in img_url:
                            raise ValueError(
                                f"Unrecognized message image format: {img_url}"
                            )
                        img_url = img_url["url"]
                    parts.append(image_loader.load_part(img_url))
                # Handle media type like LangChain.js
                # https://github.com/langchain-ai/langchainjs/blob/e536593e2585f1dd7b0afc187de4d07cb40689ba/libs/langchain-google-common/src/utils/gemini.ts#L93-L106
                elif part["type"] == "media":
                    if "mime_type" not in part:
                        raise ValueError(f"Missing mime_type in media part: {part}")
                    mime_type = part["mime_type"]
                    media_part = Part()

                    # Inline bytes and file references are mutually exclusive.
                    if "data" in part:
                        media_part.inline_data = Blob(
                            data=part["data"], mime_type=mime_type
                        )
                    elif "file_uri" in part:
                        media_part.file_data = FileData(
                            file_uri=part["file_uri"], mime_type=mime_type
                        )
                    else:
                        raise ValueError(
                            f"Media part must have either data or file_uri: {part}"
                        )
                    if "video_metadata" in part:
                        metadata = VideoMetadata(part["video_metadata"])
                        media_part.video_metadata = metadata
                    parts.append(media_part)
                else:
                    raise ValueError(
                        f"Unrecognized message part type: {part['type']}. Only text, "
                        f"image_url, and media types are supported."
                    )
            else:
                # Yolo: unknown mapping shape — degrade to its string form.
                logger.warning(
                    "Unrecognized message part format. Assuming it's a text part."
                )
                parts.append(Part(text=str(part)))
        else:
            # TODO: Maybe some of Google's native stuff
            # would hit this branch.
            raise ChatGoogleGenerativeAIError(
                "Gemini only supports text and inline_data parts."
            )
    return parts
352
-
353
-
354
def _convert_tool_message_to_parts(
    message: ToolMessage | FunctionMessage, name: Optional[str] = None
) -> list[Part]:
    """Convert a tool or function message into Gemini ``Part`` objects.

    Media blocks in list content become their own inline parts; everything
    else is packaged into a single ``FunctionResponse`` part.
    """
    # Legacy agent stores tool name in message.additional_kwargs instead of message.name
    name = message.name or name or message.additional_kwargs.get("name")
    response: Any
    parts: list[Part] = []
    if isinstance(message.content, list):
        # Split media blocks (converted to inline parts) from the remaining
        # blocks, which form the function's response payload.
        media_blocks = []
        other_blocks = []
        for block in message.content:
            if isinstance(block, dict) and (
                is_data_content_block(block) or _is_openai_image_block(block)
            ):
                media_blocks.append(block)
            else:
                other_blocks.append(block)
        parts.extend(_convert_to_parts(media_blocks))
        response = other_blocks

    elif not isinstance(message.content, str):
        response = message.content
    else:
        # String content: prefer structured JSON when it parses.
        try:
            response = json.loads(message.content)
        except json.JSONDecodeError:
            response = message.content  # leave as str representation
    part = Part(
        function_response=FunctionResponse(
            name=name,
            response=(
                # FunctionResponse.response must be a mapping; wrap scalars/lists.
                {"output": response} if not isinstance(response, dict) else response
            ),
        )
    )
    parts.append(part)
    return parts
392
-
393
-
394
def _get_ai_message_tool_messages_parts(
    tool_messages: Sequence[ToolMessage], ai_message: AIMessage
) -> list[Part]:
    """Collect the tool responses belonging to *ai_message* as a list of Parts.

    Only tool messages whose ``tool_call_id`` matches one of the AI message's
    tool calls are converted; each matching id is consumed at most once, and
    iteration stops early once every id has been matched.

    Args:
        tool_messages: Candidate tool messages, in conversation order.
        ai_message: The AI message whose ``tool_calls`` are being answered.

    Returns:
        A single flat list of ``Part`` objects for the matched responses.
    """
    # We are interested only in the tool messages that are part of the AI message
    tool_calls_ids = {tool_call["id"]: tool_call for tool_call in ai_message.tool_calls}
    parts: list[Part] = []
    # Fix: the original used `enumerate` but never used the index.
    for message in tool_messages:
        if not tool_calls_ids:
            break
        if message.tool_call_id in tool_calls_ids:
            # pop() both fetches the call and ensures the same id is never
            # matched twice.
            tool_call = tool_calls_ids.pop(message.tool_call_id)
            message_parts = _convert_tool_message_to_parts(
                message, name=tool_call.get("name")
            )
            parts.extend(message_parts)
    return parts
416
-
417
-
418
def _parse_chat_history(
    input_messages: Sequence[BaseMessage], convert_system_message_to_human: bool = False
) -> Tuple[Optional[Content], List[Content]]:
    """Split LangChain messages into a Gemini (system_instruction, contents) pair.

    Args:
        input_messages: The conversation history.
        convert_system_message_to_human: Deprecated; when True the system
            instruction is folded into the first human turn instead of being
            returned separately.

    Returns:
        A tuple of the optional system instruction and the list of
        ``Content`` turns for the request.

    Raises:
        ValueError: On a message type this converter does not support.
    """
    messages: List[Content] = []

    if convert_system_message_to_human:
        warnings.warn("Convert_system_message_to_human will be deprecated!")

    system_instruction: Optional[Content] = None
    # Tool messages are not emitted as standalone turns; they are attached to
    # the AI message whose tool_calls they answer (see the AIMessage branch).
    messages_without_tool_messages = [
        message for message in input_messages if not isinstance(message, ToolMessage)
    ]
    tool_messages = [
        message for message in input_messages if isinstance(message, ToolMessage)
    ]
    for i, message in enumerate(messages_without_tool_messages):
        if isinstance(message, SystemMessage):
            system_parts = _convert_to_parts(message.content)
            if i == 0:
                system_instruction = Content(parts=system_parts)
            elif system_instruction is not None:
                # Later system messages extend the initial instruction.
                system_instruction.parts.extend(system_parts)
            else:
                # System message after non-system turns with no instruction
                # started at position 0: silently ignored.
                pass
            continue
        elif isinstance(message, AIMessage):
            role = "model"
            if message.tool_calls:
                # One function_call part per tool call, followed by a "user"
                # turn carrying the matching tool responses.
                ai_message_parts = []
                for tool_call in message.tool_calls:
                    function_call = FunctionCall(
                        {
                            "name": tool_call["name"],
                            "args": tool_call["args"],
                        }
                    )
                    ai_message_parts.append(Part(function_call=function_call))
                tool_messages_parts = _get_ai_message_tool_messages_parts(
                    tool_messages=tool_messages, ai_message=message
                )
                messages.append(Content(role=role, parts=ai_message_parts))
                messages.append(Content(role="user", parts=tool_messages_parts))
                continue
            elif raw_function_call := message.additional_kwargs.get("function_call"):
                # Legacy single-function-call representation (OpenAI-style
                # additional_kwargs with JSON-encoded arguments).
                function_call = FunctionCall(
                    {
                        "name": raw_function_call["name"],
                        "args": json.loads(raw_function_call["arguments"]),
                    }
                )
                parts = [Part(function_call=function_call)]
            else:
                parts = _convert_to_parts(message.content)
        elif isinstance(message, HumanMessage):
            role = "user"
            parts = _convert_to_parts(message.content)
            # Deprecated path: prepend the system instruction to the first
            # human turn (only when it immediately follows the system message).
            if i == 1 and convert_system_message_to_human and system_instruction:
                parts = [p for p in system_instruction.parts] + parts
                system_instruction = None
        elif isinstance(message, FunctionMessage):
            role = "user"
            parts = _convert_tool_message_to_parts(message)
        else:
            raise ValueError(
                f"Unexpected message with type {type(message)} at the position {i}."
            )

        messages.append(Content(role=role, parts=parts))
    return system_instruction, messages
487
-
488
-
489
def _parse_response_candidate(
    response_candidate: Candidate, streaming: bool = False
) -> AIMessage:
    """Convert one Gemini response candidate into an ``AIMessage``.

    Walks the candidate's parts and accumulates ``content`` (str when there is
    a single text part, otherwise a list mixing strings and typed dicts for
    thinking, executable code, execution results and images), plus tool calls.

    Args:
        response_candidate: One candidate from a ``GenerateContentResponse``.
        streaming: When True, return an ``AIMessageChunk`` with
            ``tool_call_chunks`` instead of resolved ``tool_calls``.

    Returns:
        An ``AIMessage`` (or ``AIMessageChunk`` when streaming).
    """
    content: Union[None, str, List[Union[str, dict]]] = None
    additional_kwargs = {}
    tool_calls = []
    invalid_tool_calls = []
    tool_call_chunks = []

    for part in response_candidate.content.parts:
        try:
            text: Optional[str] = part.text
            # Remove erroneous newline character if present
            if not streaming and text is not None:
                text = text.rstrip("\n")
        except AttributeError:
            text = None

        if part.thought:
            # Reasoning ("thinking") parts are surfaced as typed dicts.
            thinking_message = {
                "type": "thinking",
                "thinking": part.text,
            }
            if not content:
                content = [thinking_message]
            elif isinstance(content, str):
                content = [thinking_message, content]
            elif isinstance(content, list):
                content.append(thinking_message)
            else:
                raise Exception("Unexpected content type")

        elif text is not None:
            # Plain text: keep content a bare str until a second item arrives.
            if not content:
                content = text
            elif isinstance(content, str) and text:
                content = [content, text]
            elif isinstance(content, list) and text:
                content.append(text)
            elif text:
                raise Exception("Unexpected content type")

        if hasattr(part, "executable_code") and part.executable_code is not None:
            if part.executable_code.code and part.executable_code.language:
                code_message = {
                    "type": "executable_code",
                    "executable_code": part.executable_code.code,
                    "language": part.executable_code.language,
                }
                if not content:
                    content = [code_message]
                elif isinstance(content, str):
                    content = [content, code_message]
                elif isinstance(content, list):
                    content.append(code_message)
                else:
                    raise Exception("Unexpected content type")

        if (
            hasattr(part, "code_execution_result")
            and part.code_execution_result is not None
        ):
            if part.code_execution_result.output:
                execution_result = {
                    "type": "code_execution_result",
                    "code_execution_result": part.code_execution_result.output,
                }

                if not content:
                    content = [execution_result]
                elif isinstance(content, str):
                    content = [content, execution_result]
                elif isinstance(content, list):
                    content.append(execution_result)
                else:
                    raise Exception("Unexpected content type")

        # NOTE(review): assumes proto gives a default (empty) inline_data with
        # an empty mime_type when absent, so startswith() is safe — confirm.
        if part.inline_data.mime_type.startswith("image/"):
            # Strip the "image/" prefix (6 chars) to get the format name.
            image_format = part.inline_data.mime_type[6:]
            message = {
                "type": "image_url",
                "image_url": {
                    "url": image_bytes_to_b64_string(
                        part.inline_data.data, image_format=image_format
                    )
                },
            }

            if not content:
                content = [message]
            elif isinstance(content, str) and message:
                content = [content, message]
            elif isinstance(content, list) and message:
                content.append(message)
            elif message:
                raise Exception("Unexpected content type")

        if part.function_call:
            function_call = {"name": part.function_call.name}
            # dump to match other function calling llm for now
            function_call_args_dict = proto.Message.to_dict(part.function_call)["args"]
            function_call["arguments"] = json.dumps(
                {k: function_call_args_dict[k] for k in function_call_args_dict}
            )
            additional_kwargs["function_call"] = function_call

            if streaming:
                tool_call_chunks.append(
                    tool_call_chunk(
                        name=function_call.get("name"),
                        args=function_call.get("arguments"),
                        id=function_call.get("id", str(uuid.uuid4())),
                        index=function_call.get("index"),  # type: ignore
                    )
                )
            else:
                try:
                    tool_call_dict = parse_tool_calls(
                        [{"function": function_call}],
                        return_id=False,
                    )[0]
                except Exception as e:
                    # Unparseable arguments: record as an invalid tool call
                    # instead of failing the whole response.
                    invalid_tool_calls.append(
                        invalid_tool_call(
                            name=function_call.get("name"),
                            args=function_call.get("arguments"),
                            id=function_call.get("id", str(uuid.uuid4())),
                            error=str(e),
                        )
                    )
                else:
                    tool_calls.append(
                        tool_call(
                            name=tool_call_dict["name"],
                            args=tool_call_dict["args"],
                            id=tool_call_dict.get("id", str(uuid.uuid4())),
                        )
                    )
    if content is None:
        content = ""
    if any(isinstance(item, dict) and "executable_code" in item for item in content):
        warnings.warn(
            """
        ⚠️ Warning: Output may vary each run.
        - 'executable_code': Always present.
        - 'execution_result' & 'image_url': May be absent for some queries.

        Validate before using in production.
        """
        )

    if streaming:
        return AIMessageChunk(
            content=cast(Union[str, List[Union[str, Dict[Any, Any]]]], content),
            additional_kwargs=additional_kwargs,
            tool_call_chunks=tool_call_chunks,
        )

    return AIMessage(
        content=cast(Union[str, List[Union[str, Dict[Any, Any]]]], content),
        additional_kwargs=additional_kwargs,
        tool_calls=tool_calls,
        invalid_tool_calls=invalid_tool_calls,
    )
653
-
654
-
655
def _response_to_result(
    response: GenerateContentResponse,
    stream: bool = False,
    prev_usage: Optional[UsageMetadata] = None,
) -> ChatResult:
    """Converts a PaLM API response into a LangChain ChatResult.

    Args:
        response: The raw Gemini response.
        stream: When True, wrap candidates in ``ChatGenerationChunk``.
        prev_usage: Usage totals already reported by earlier chunks of the
            same stream; subtracted so each chunk reports only its delta.

    Returns:
        A ``ChatResult`` with one generation per candidate (or a single empty
        generation when the response has no candidates).
    """
    llm_output = {"prompt_feedback": proto.Message.to_dict(response.prompt_feedback)}

    # previous usage metadata needs to be subtracted because gemini api returns
    # already-accumulated token counts with each chunk
    prev_input_tokens = prev_usage["input_tokens"] if prev_usage else 0
    prev_output_tokens = prev_usage["output_tokens"] if prev_usage else 0
    prev_total_tokens = prev_usage["total_tokens"] if prev_usage else 0

    # Get usage metadata
    try:
        input_tokens = response.usage_metadata.prompt_token_count
        output_tokens = response.usage_metadata.candidates_token_count
        total_tokens = response.usage_metadata.total_token_count
        thought_tokens = response.usage_metadata.thoughts_token_count
        cache_read_tokens = response.usage_metadata.cached_content_token_count
        if input_tokens + output_tokens + cache_read_tokens + total_tokens > 0:
            # Only attach reasoning details when thought tokens were counted.
            if thought_tokens > 0:
                lc_usage = UsageMetadata(
                    input_tokens=input_tokens - prev_input_tokens,
                    output_tokens=output_tokens - prev_output_tokens,
                    total_tokens=total_tokens - prev_total_tokens,
                    input_token_details={"cache_read": cache_read_tokens},
                    output_token_details={"reasoning": thought_tokens},
                )
            else:
                lc_usage = UsageMetadata(
                    input_tokens=input_tokens - prev_input_tokens,
                    output_tokens=output_tokens - prev_output_tokens,
                    total_tokens=total_tokens - prev_total_tokens,
                    input_token_details={"cache_read": cache_read_tokens},
                )
        else:
            lc_usage = None
    except AttributeError:
        # Older response shapes may lack usage_metadata fields entirely.
        lc_usage = None

    generations: List[ChatGeneration] = []

    for candidate in response.candidates:
        generation_info = {}
        if candidate.finish_reason:
            generation_info["finish_reason"] = candidate.finish_reason.name
            # Add model_name in last chunk
            generation_info["model_name"] = response.model_version
        generation_info["safety_ratings"] = [
            proto.Message.to_dict(safety_rating, use_integers_for_enums=False)
            for safety_rating in candidate.safety_ratings
        ]
        try:
            if candidate.grounding_metadata:
                generation_info["grounding_metadata"] = proto.Message.to_dict(
                    candidate.grounding_metadata
                )
        except AttributeError:
            # grounding_metadata is not present on all API versions.
            pass
        message = _parse_response_candidate(candidate, streaming=stream)
        message.usage_metadata = lc_usage
        if stream:
            generations.append(
                ChatGenerationChunk(
                    message=cast(AIMessageChunk, message),
                    generation_info=generation_info,
                )
            )
        else:
            generations.append(
                ChatGeneration(message=message, generation_info=generation_info)
            )
    if not response.candidates:
        # Likely a "prompt feedback" violation (e.g., toxic input)
        # Raising an error would be different than how OpenAI handles it,
        # so we'll just log a warning and continue with an empty message.
        logger.warning(
            "Gemini produced an empty response. Continuing with empty message\n"
            f"Feedback: {response.prompt_feedback}"
        )
        if stream:
            generations = [
                ChatGenerationChunk(
                    message=AIMessageChunk(content=""), generation_info={}
                )
            ]
        else:
            generations = [ChatGeneration(message=AIMessage(""), generation_info={})]
    return ChatResult(generations=generations, llm_output=llm_output)
746
-
747
-
748
- def _is_event_loop_running() -> bool:
749
  try:
750
- asyncio.get_running_loop()
751
- return True
752
- except RuntimeError:
753
- return False
754
-
755
-
756
- class ChatGoogleGenerativeAI(_BaseGoogleGenerativeAI, BaseChatModel):
757
- """`Google AI` chat models integration.
758
-
759
- Instantiation:
760
- To use, you must have either:
761
-
762
- 1. The ``GOOGLE_API_KEY`` environment variable set with your API key, or
763
- 2. Pass your API key using the google_api_key kwarg
764
- to the ChatGoogleGenerativeAI constructor.
765
-
766
- .. code-block:: python
767
-
768
- from langchain_google_genai import ChatGoogleGenerativeAI
769
-
770
- llm = ChatGoogleGenerativeAI(model="gemini-2.0-flash-001")
771
- llm.invoke("Write me a ballad about LangChain")
772
-
773
- Invoke:
774
- .. code-block:: python
775
-
776
- messages = [
777
- ("system", "Translate the user sentence to French."),
778
- ("human", "I love programming."),
779
- ]
780
- llm.invoke(messages)
781
-
782
- .. code-block:: python
783
-
784
- AIMessage(
785
- content="J'adore programmer. \\n",
786
- response_metadata={'prompt_feedback': {'block_reason': 0, 'safety_ratings': []}, 'finish_reason': 'STOP', 'safety_ratings': [{'category': 'HARM_CATEGORY_SEXUALLY_EXPLICIT', 'probability': 'NEGLIGIBLE', 'blocked': False}, {'category': 'HARM_CATEGORY_HATE_SPEECH', 'probability': 'NEGLIGIBLE', 'blocked': False}, {'category': 'HARM_CATEGORY_HARASSMENT', 'probability': 'NEGLIGIBLE', 'blocked': False}, {'category': 'HARM_CATEGORY_DANGEROUS_CONTENT', 'probability': 'NEGLIGIBLE', 'blocked': False}]},
787
- id='run-56cecc34-2e54-4b52-a974-337e47008ad2-0',
788
- usage_metadata={'input_tokens': 18, 'output_tokens': 5, 'total_tokens': 23}
789
- )
790
-
791
- Stream:
792
- .. code-block:: python
793
-
794
- for chunk in llm.stream(messages):
795
- print(chunk)
796
-
797
- .. code-block:: python
798
-
799
- AIMessageChunk(content='J', response_metadata={'finish_reason': 'STOP', 'safety_ratings': []}, id='run-e905f4f4-58cb-4a10-a960-448a2bb649e3', usage_metadata={'input_tokens': 18, 'output_tokens': 1, 'total_tokens': 19})
800
- AIMessageChunk(content="'adore programmer. \\n", response_metadata={'finish_reason': 'STOP', 'safety_ratings': [{'category': 'HARM_CATEGORY_SEXUALLY_EXPLICIT', 'probability': 'NEGLIGIBLE', 'blocked': False}, {'category': 'HARM_CATEGORY_HATE_SPEECH', 'probability': 'NEGLIGIBLE', 'blocked': False}, {'category': 'HARM_CATEGORY_HARASSMENT', 'probability': 'NEGLIGIBLE', 'blocked': False}, {'category': 'HARM_CATEGORY_DANGEROUS_CONTENT', 'probability': 'NEGLIGIBLE', 'blocked': False}]}, id='run-e905f4f4-58cb-4a10-a960-448a2bb649e3', usage_metadata={'input_tokens': 18, 'output_tokens': 5, 'total_tokens': 23})
801
-
802
- .. code-block:: python
803
-
804
- stream = llm.stream(messages)
805
- full = next(stream)
806
- for chunk in stream:
807
- full += chunk
808
- full
809
-
810
- .. code-block:: python
811
-
812
- AIMessageChunk(
813
- content="J'adore programmer. \\n",
814
- response_metadata={'finish_reason': 'STOPSTOP', 'safety_ratings': [{'category': 'HARM_CATEGORY_SEXUALLY_EXPLICIT', 'probability': 'NEGLIGIBLE', 'blocked': False}, {'category': 'HARM_CATEGORY_HATE_SPEECH', 'probability': 'NEGLIGIBLE', 'blocked': False}, {'category': 'HARM_CATEGORY_HARASSMENT', 'probability': 'NEGLIGIBLE', 'blocked': False}, {'category': 'HARM_CATEGORY_DANGEROUS_CONTENT', 'probability': 'NEGLIGIBLE', 'blocked': False}]},
815
- id='run-3ce13a42-cd30-4ad7-a684-f1f0b37cdeec',
816
- usage_metadata={'input_tokens': 36, 'output_tokens': 6, 'total_tokens': 42}
817
- )
818
-
819
- Async:
820
- .. code-block:: python
821
-
822
- await llm.ainvoke(messages)
823
-
824
- # stream:
825
- # async for chunk in (await llm.astream(messages))
826
-
827
- # batch:
828
- # await llm.abatch([messages])
829
-
830
- Context Caching:
831
- Context caching allows you to store and reuse content (e.g., PDFs, images) for faster processing.
832
- The `cached_content` parameter accepts a cache name created via the Google Generative AI API.
833
- Below are two examples: caching a single file directly and caching multiple files using `Part`.
834
-
835
- Single File Example:
836
- This caches a single file and queries it.
837
-
838
- .. code-block:: python
839
-
840
- from google import genai
841
- from google.genai import types
842
- import time
843
- from langchain_google_genai import ChatGoogleGenerativeAI
844
- from langchain_core.messages import HumanMessage
845
-
846
- client = genai.Client()
847
-
848
- # Upload file
849
- file = client.files.upload(file="./example_file")
850
- while file.state.name == 'PROCESSING':
851
- time.sleep(2)
852
- file = client.files.get(name=file.name)
853
-
854
- # Create cache
855
- model = 'models/gemini-1.5-flash-latest'
856
- cache = client.caches.create(
857
- model=model,
858
- config=types.CreateCachedContentConfig(
859
- display_name='Cached Content',
860
- system_instruction=(
861
- 'You are an expert content analyzer, and your job is to answer '
862
- 'the user\'s query based on the file you have access to.'
863
- ),
864
- contents=[file],
865
- ttl="300s",
866
- )
867
- )
868
-
869
- # Query with LangChain
870
- llm = ChatGoogleGenerativeAI(
871
- model=model,
872
- cached_content=cache.name,
873
- )
874
- message = HumanMessage(content="Summarize the main points of the content.")
875
- llm.invoke([message])
876
-
877
- Multiple Files Example:
878
- This caches two files using `Part` and queries them together.
879
-
880
- .. code-block:: python
881
-
882
- from google import genai
883
- from google.genai.types import CreateCachedContentConfig, Content, Part
884
- import time
885
- from langchain_google_genai import ChatGoogleGenerativeAI
886
- from langchain_core.messages import HumanMessage
887
-
888
- client = genai.Client()
889
-
890
- # Upload files
891
- file_1 = client.files.upload(file="./file1")
892
- while file_1.state.name == 'PROCESSING':
893
- time.sleep(2)
894
- file_1 = client.files.get(name=file_1.name)
895
-
896
- file_2 = client.files.upload(file="./file2")
897
- while file_2.state.name == 'PROCESSING':
898
- time.sleep(2)
899
- file_2 = client.files.get(name=file_2.name)
900
-
901
- # Create cache with multiple files
902
- contents = [
903
- Content(
904
- role="user",
905
- parts=[
906
- Part.from_uri(file_uri=file_1.uri, mime_type=file_1.mime_type),
907
- Part.from_uri(file_uri=file_2.uri, mime_type=file_2.mime_type),
908
- ],
909
- )
910
- ]
911
- model = "gemini-1.5-flash-latest"
912
- cache = client.caches.create(
913
- model=model,
914
- config=CreateCachedContentConfig(
915
- display_name='Cached Contents',
916
- system_instruction=(
917
- 'You are an expert content analyzer, and your job is to answer '
918
- 'the user\'s query based on the files you have access to.'
919
- ),
920
- contents=contents,
921
- ttl="300s",
922
- )
923
- )
924
-
925
- # Query with LangChain
926
  llm = ChatGoogleGenerativeAI(
927
- model=model,
928
- cached_content=cache.name,
 
929
  )
930
- message = HumanMessage(content="Provide a summary of the key information across both files.")
931
- llm.invoke([message])
932
-
933
- Tool calling:
934
- .. code-block:: python
935
-
936
- from pydantic import BaseModel, Field
937
-
938
-
939
- class GetWeather(BaseModel):
940
- '''Get the current weather in a given location'''
941
-
942
- location: str = Field(
943
- ..., description="The city and state, e.g. San Francisco, CA"
944
- )
945
-
946
-
947
- class GetPopulation(BaseModel):
948
- '''Get the current population in a given location'''
949
-
950
- location: str = Field(
951
- ..., description="The city and state, e.g. San Francisco, CA"
952
- )
953
-
954
-
955
- llm_with_tools = llm.bind_tools([GetWeather, GetPopulation])
956
- ai_msg = llm_with_tools.invoke(
957
- "Which city is hotter today and which is bigger: LA or NY?"
958
- )
959
- ai_msg.tool_calls
960
-
961
- .. code-block:: python
962
-
963
- [{'name': 'GetWeather',
964
- 'args': {'location': 'Los Angeles, CA'},
965
- 'id': 'c186c99f-f137-4d52-947f-9e3deabba6f6'},
966
- {'name': 'GetWeather',
967
- 'args': {'location': 'New York City, NY'},
968
- 'id': 'cebd4a5d-e800-4fa5-babd-4aa286af4f31'},
969
- {'name': 'GetPopulation',
970
- 'args': {'location': 'Los Angeles, CA'},
971
- 'id': '4f92d897-f5e4-4d34-a3bc-93062c92591e'},
972
- {'name': 'GetPopulation',
973
- 'args': {'location': 'New York City, NY'},
974
- 'id': '634582de-5186-4e4b-968b-f192f0a93678'}]
975
-
976
- Use Search with Gemini 2:
977
- .. code-block:: python
978
-
979
- from google.ai.generativelanguage_v1beta.types import Tool as GenAITool
980
- llm = ChatGoogleGenerativeAI(model="gemini-2.0-flash-exp")
981
- resp = llm.invoke(
982
- "When is the next total solar eclipse in US?",
983
- tools=[GenAITool(google_search={})],
984
- )
985
-
986
- Structured output:
987
- .. code-block:: python
988
-
989
- from typing import Optional
990
-
991
- from pydantic import BaseModel, Field
992
-
993
-
994
- class Joke(BaseModel):
995
- '''Joke to tell user.'''
996
-
997
- setup: str = Field(description="The setup of the joke")
998
- punchline: str = Field(description="The punchline to the joke")
999
- rating: Optional[int] = Field(description="How funny the joke is, from 1 to 10")
1000
-
1001
-
1002
- structured_llm = llm.with_structured_output(Joke)
1003
- structured_llm.invoke("Tell me a joke about cats")
1004
-
1005
- .. code-block:: python
1006
-
1007
- Joke(
1008
- setup='Why are cats so good at video games?',
1009
- punchline='They have nine lives on the internet',
1010
- rating=None
1011
- )
1012
-
1013
- Image input:
1014
- .. code-block:: python
1015
-
1016
- import base64
1017
- import httpx
1018
- from langchain_core.messages import HumanMessage
1019
-
1020
- image_url = "https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg"
1021
- image_data = base64.b64encode(httpx.get(image_url).content).decode("utf-8")
1022
- message = HumanMessage(
1023
- content=[
1024
- {"type": "text", "text": "describe the weather in this image"},
1025
- {
1026
- "type": "image_url",
1027
- "image_url": {"url": f"data:image/jpeg;base64,{image_data}"},
1028
- },
1029
- ]
1030
  )
1031
- ai_msg = llm.invoke([message])
1032
- ai_msg.content
1033
-
1034
- .. code-block:: python
1035
-
1036
- 'The weather in this image appears to be sunny and pleasant. The sky is a bright blue with scattered white clouds, suggesting fair weather. The lush green grass and trees indicate a warm and possibly slightly breezy day. There are no signs of rain or storms.'
1037
-
1038
- Token usage:
1039
- .. code-block:: python
1040
-
1041
- ai_msg = llm.invoke(messages)
1042
- ai_msg.usage_metadata
1043
-
1044
- .. code-block:: python
1045
-
1046
- {'input_tokens': 18, 'output_tokens': 5, 'total_tokens': 23}
1047
-
1048
-
1049
- Response metadata
1050
- .. code-block:: python
1051
-
1052
- ai_msg = llm.invoke(messages)
1053
- ai_msg.response_metadata
1054
-
1055
- .. code-block:: python
1056
-
1057
- {
1058
- 'prompt_feedback': {'block_reason': 0, 'safety_ratings': []},
1059
- 'finish_reason': 'STOP',
1060
- 'safety_ratings': [{'category': 'HARM_CATEGORY_SEXUALLY_EXPLICIT', 'probability': 'NEGLIGIBLE', 'blocked': False}, {'category': 'HARM_CATEGORY_HATE_SPEECH', 'probability': 'NEGLIGIBLE', 'blocked': False}, {'category': 'HARM_CATEGORY_HARASSMENT', 'probability': 'NEGLIGIBLE', 'blocked': False}, {'category': 'HARM_CATEGORY_DANGEROUS_CONTENT', 'probability': 'NEGLIGIBLE', 'blocked': False}]
1061
- }
1062
-
1063
- """ # noqa: E501
1064
-
1065
- client: Any = Field(default=None, exclude=True) #: :meta private:
1066
- async_client_running: Any = Field(default=None, exclude=True) #: :meta private:
1067
- default_metadata: Sequence[Tuple[str, str]] = Field(
1068
- default_factory=list
1069
- ) #: :meta private:
1070
-
1071
- convert_system_message_to_human: bool = False
1072
- """Whether to merge any leading SystemMessage into the following HumanMessage.
1073
-
1074
- Gemini does not support system messages; any unsupported messages will
1075
- raise an error."""
1076
-
1077
- cached_content: Optional[str] = None
1078
- """The name of the cached content used as context to serve the prediction.
1079
-
1080
- Note: only used in explicit caching, where users can have control over caching
1081
- (e.g. what content to cache) and enjoy guaranteed cost savings. Format:
1082
- ``cachedContents/{cachedContent}``.
1083
- """
1084
-
1085
- model_kwargs: dict[str, Any] = Field(default_factory=dict)
1086
- """Holds any unexpected initialization parameters."""
1087
-
1088
- def __init__(self, **kwargs: Any) -> None:
1089
- """Needed for arg validation."""
1090
- # Get all valid field names, including aliases
1091
- valid_fields = set()
1092
- for field_name, field_info in self.__class__.model_fields.items():
1093
- valid_fields.add(field_name)
1094
- if hasattr(field_info, "alias") and field_info.alias is not None:
1095
- valid_fields.add(field_info.alias)
1096
-
1097
- # Check for unrecognized arguments
1098
- for arg in kwargs:
1099
- if arg not in valid_fields:
1100
- suggestions = get_close_matches(arg, valid_fields, n=1)
1101
- suggestion = (
1102
- f" Did you mean: '{suggestions[0]}'?" if suggestions else ""
1103
- )
1104
- logger.warning(
1105
- f"Unexpected argument '{arg}' "
1106
- f"provided to ChatGoogleGenerativeAI.{suggestion}"
1107
- )
1108
- super().__init__(**kwargs)
1109
-
1110
- model_config = ConfigDict(
1111
- populate_by_name=True,
1112
- )
1113
-
1114
- @property
1115
- def lc_secrets(self) -> Dict[str, str]:
1116
- return {"google_api_key": "GOOGLE_API_KEY"}
1117
-
1118
- @property
1119
- def _llm_type(self) -> str:
1120
- return "chat-google-generative-ai"
1121
-
1122
- @property
1123
- def _supports_code_execution(self) -> bool:
1124
- return (
1125
- "gemini-1.5-pro" in self.model
1126
- or "gemini-1.5-flash" in self.model
1127
- or "gemini-2" in self.model
1128
- )
1129
-
1130
- @classmethod
1131
- def is_lc_serializable(self) -> bool:
1132
- return True
1133
-
1134
- @model_validator(mode="before")
1135
- @classmethod
1136
- def build_extra(cls, values: dict[str, Any]) -> Any:
1137
- """Build extra kwargs from additional params that were passed in."""
1138
- all_required_field_names = get_pydantic_field_names(cls)
1139
- values = _build_model_kwargs(values, all_required_field_names)
1140
- return values
1141
-
1142
- @model_validator(mode="after")
1143
- def validate_environment(self) -> Self:
1144
- """Validates params and passes them to google-generativeai package."""
1145
- if self.temperature is not None and not 0 <= self.temperature <= 2.0:
1146
- raise ValueError("temperature must be in the range [0.0, 2.0]")
1147
-
1148
- if self.top_p is not None and not 0 <= self.top_p <= 1:
1149
- raise ValueError("top_p must be in the range [0.0, 1.0]")
1150
-
1151
- if self.top_k is not None and self.top_k <= 0:
1152
- raise ValueError("top_k must be positive")
1153
-
1154
- if not any(
1155
- self.model.startswith(prefix) for prefix in ("models/", "tunedModels/")
1156
- ):
1157
- self.model = f"models/{self.model}"
1158
-
1159
- additional_headers = self.additional_headers or {}
1160
- self.default_metadata = tuple(additional_headers.items())
1161
- client_info = get_client_info(f"ChatGoogleGenerativeAI:{self.model}")
1162
- google_api_key = None
1163
- if not self.credentials:
1164
- if isinstance(self.google_api_key, SecretStr):
1165
- google_api_key = self.google_api_key.get_secret_value()
1166
- else:
1167
- google_api_key = self.google_api_key
1168
- transport: Optional[str] = self.transport
1169
- self.client = genaix.build_generative_service(
1170
- credentials=self.credentials,
1171
- api_key=google_api_key,
1172
- client_info=client_info,
1173
- client_options=self.client_options,
1174
- transport=transport,
1175
  )
1176
- self.async_client_running = None
1177
- return self
1178
-
1179
- @property
1180
- def async_client(self) -> v1betaGenerativeServiceAsyncClient:
1181
- google_api_key = None
1182
- if not self.credentials:
1183
- if isinstance(self.google_api_key, SecretStr):
1184
- google_api_key = self.google_api_key.get_secret_value()
1185
- else:
1186
- google_api_key = self.google_api_key
1187
- # NOTE: genaix.build_generative_async_service requires
1188
- # a running event loop, which causes an error
1189
- # when initialized inside a ThreadPoolExecutor.
1190
- # this check ensures that async client is only initialized
1191
- # within an asyncio event loop to avoid the error
1192
- if not self.async_client_running and _is_event_loop_running():
1193
- # async clients don't support "rest" transport
1194
- # https://github.com/googleapis/gapic-generator-python/issues/1962
1195
- transport = self.transport
1196
- if transport == "rest":
1197
- transport = "grpc_asyncio"
1198
- self.async_client_running = genaix.build_generative_async_service(
1199
- credentials=self.credentials,
1200
- api_key=google_api_key,
1201
- client_info=get_client_info(f"ChatGoogleGenerativeAI:{self.model}"),
1202
- client_options=self.client_options,
1203
- transport=transport,
1204
- )
1205
- return self.async_client_running
1206
-
1207
- @property
1208
- def _identifying_params(self) -> Dict[str, Any]:
1209
- """Get the identifying parameters."""
1210
  return {
1211
- "model": self.model,
1212
- "temperature": self.temperature,
1213
- "top_k": self.top_k,
1214
- "n": self.n,
1215
- "safety_settings": self.safety_settings,
1216
- "response_modalities": self.response_modalities,
1217
- "thinking_budget": self.thinking_budget,
1218
- "include_thoughts": self.include_thoughts,
1219
  }
1220
-
1221
- def invoke(
1222
- self,
1223
- input: LanguageModelInput,
1224
- config: Optional[RunnableConfig] = None,
1225
- *,
1226
- code_execution: Optional[bool] = None,
1227
- stop: Optional[list[str]] = None,
1228
- **kwargs: Any,
1229
- ) -> BaseMessage:
1230
- """
1231
- Enable code execution. Supported on: gemini-1.5-pro, gemini-1.5-flash,
1232
- gemini-2.0-flash, and gemini-2.0-pro. When enabled, the model can execute
1233
- code to solve problems.
1234
- """
1235
-
1236
- """Override invoke to add code_execution parameter."""
1237
-
1238
- if code_execution is not None:
1239
- if not self._supports_code_execution:
1240
- raise ValueError(
1241
- f"Code execution is only supported on Gemini 1.5 Pro, \
1242
- Gemini 1.5 Flash, "
1243
- f"Gemini 2.0 Flash, and Gemini 2.0 Pro models. \
1244
- Current model: {self.model}"
1245
- )
1246
- if "tools" not in kwargs:
1247
- code_execution_tool = GoogleTool(code_execution=CodeExecution())
1248
- kwargs["tools"] = [code_execution_tool]
1249
-
1250
- else:
1251
- raise ValueError(
1252
- "Tools are already defined." "code_execution tool can't be defined"
1253
- )
1254
-
1255
- return super().invoke(input, config, stop=stop, **kwargs)
1256
-
1257
- def _get_ls_params(
1258
- self, stop: Optional[List[str]] = None, **kwargs: Any
1259
- ) -> LangSmithParams:
1260
- """Get standard params for tracing."""
1261
- params = self._get_invocation_params(stop=stop, **kwargs)
1262
- models_prefix = "models/"
1263
- ls_model_name = (
1264
- self.model[len(models_prefix) :]
1265
- if self.model and self.model.startswith(models_prefix)
1266
- else self.model
1267
  )
1268
- ls_params = LangSmithParams(
1269
- ls_provider="google_genai",
1270
- ls_model_name=ls_model_name,
1271
- ls_model_type="chat",
1272
- ls_temperature=params.get("temperature", self.temperature),
1273
- )
1274
- if ls_max_tokens := params.get("max_output_tokens", self.max_output_tokens):
1275
- ls_params["ls_max_tokens"] = ls_max_tokens
1276
- if ls_stop := stop or params.get("stop", None):
1277
- ls_params["ls_stop"] = ls_stop
1278
- return ls_params
1279
-
1280
- def _prepare_params(
1281
- self,
1282
- stop: Optional[List[str]],
1283
- generation_config: Optional[Dict[str, Any]] = None,
1284
- ) -> GenerationConfig:
1285
- gen_config = {
1286
- k: v
1287
- for k, v in {
1288
- "candidate_count": self.n,
1289
- "temperature": self.temperature,
1290
- "stop_sequences": stop,
1291
- "max_output_tokens": self.max_output_tokens,
1292
- "top_k": self.top_k,
1293
- "top_p": self.top_p,
1294
- "response_modalities": self.response_modalities,
1295
- "thinking_config": (
1296
- (
1297
- {"thinking_budget": self.thinking_budget}
1298
- if self.thinking_budget is not None
1299
- else {}
1300
- )
1301
- | (
1302
- {"include_thoughts": self.include_thoughts}
1303
- if self.include_thoughts is not None
1304
- else {}
1305
- )
1306
- )
1307
- if self.thinking_budget is not None or self.include_thoughts is not None
1308
- else None,
1309
- }.items()
1310
- if v is not None
1311
  }
1312
- if generation_config:
1313
- gen_config = {**gen_config, **generation_config}
1314
- return GenerationConfig(**gen_config)
1315
-
1316
- def _generate(
1317
- self,
1318
- messages: List[BaseMessage],
1319
- stop: Optional[List[str]] = None,
1320
- run_manager: Optional[CallbackManagerForLLMRun] = None,
1321
- *,
1322
- tools: Optional[Sequence[Union[_ToolDict, GoogleTool]]] = None,
1323
- functions: Optional[Sequence[_FunctionDeclarationType]] = None,
1324
- safety_settings: Optional[SafetySettingDict] = None,
1325
- tool_config: Optional[Union[Dict, _ToolConfigDict]] = None,
1326
- generation_config: Optional[Dict[str, Any]] = None,
1327
- cached_content: Optional[str] = None,
1328
- tool_choice: Optional[Union[_ToolChoiceType, bool]] = None,
1329
- **kwargs: Any,
1330
- ) -> ChatResult:
1331
- request = self._prepare_request(
1332
- messages,
1333
- stop=stop,
1334
- tools=tools,
1335
- functions=functions,
1336
- safety_settings=safety_settings,
1337
- tool_config=tool_config,
1338
- generation_config=generation_config,
1339
- cached_content=cached_content or self.cached_content,
1340
- tool_choice=tool_choice,
1341
- )
1342
- response: GenerateContentResponse = _chat_with_retry(
1343
- request=request,
1344
- **kwargs,
1345
- generation_method=self.client.generate_content,
1346
- metadata=self.default_metadata,
1347
- )
1348
- return _response_to_result(response)
1349
-
1350
- async def _agenerate(
1351
- self,
1352
- messages: List[BaseMessage],
1353
- stop: Optional[List[str]] = None,
1354
- run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
1355
- *,
1356
- tools: Optional[Sequence[Union[_ToolDict, GoogleTool]]] = None,
1357
- functions: Optional[Sequence[_FunctionDeclarationType]] = None,
1358
- safety_settings: Optional[SafetySettingDict] = None,
1359
- tool_config: Optional[Union[Dict, _ToolConfigDict]] = None,
1360
- generation_config: Optional[Dict[str, Any]] = None,
1361
- cached_content: Optional[str] = None,
1362
- tool_choice: Optional[Union[_ToolChoiceType, bool]] = None,
1363
- **kwargs: Any,
1364
- ) -> ChatResult:
1365
- if not self.async_client:
1366
- updated_kwargs = {
1367
- **kwargs,
1368
- **{
1369
- "tools": tools,
1370
- "functions": functions,
1371
- "safety_settings": safety_settings,
1372
- "tool_config": tool_config,
1373
- "generation_config": generation_config,
1374
- },
1375
- }
1376
- return await super()._agenerate(
1377
- messages, stop, run_manager, **updated_kwargs
1378
- )
1379
-
1380
- request = self._prepare_request(
1381
- messages,
1382
- stop=stop,
1383
- tools=tools,
1384
- functions=functions,
1385
- safety_settings=safety_settings,
1386
- tool_config=tool_config,
1387
- generation_config=generation_config,
1388
- cached_content=cached_content or self.cached_content,
1389
- tool_choice=tool_choice,
1390
- )
1391
- response: GenerateContentResponse = await _achat_with_retry(
1392
- request=request,
1393
- **kwargs,
1394
- generation_method=self.async_client.generate_content,
1395
- metadata=self.default_metadata,
1396
- )
1397
- return _response_to_result(response)
1398
-
1399
- def _stream(
1400
- self,
1401
- messages: List[BaseMessage],
1402
- stop: Optional[List[str]] = None,
1403
- run_manager: Optional[CallbackManagerForLLMRun] = None,
1404
- *,
1405
- tools: Optional[Sequence[Union[_ToolDict, GoogleTool]]] = None,
1406
- functions: Optional[Sequence[_FunctionDeclarationType]] = None,
1407
- safety_settings: Optional[SafetySettingDict] = None,
1408
- tool_config: Optional[Union[Dict, _ToolConfigDict]] = None,
1409
- generation_config: Optional[Dict[str, Any]] = None,
1410
- cached_content: Optional[str] = None,
1411
- tool_choice: Optional[Union[_ToolChoiceType, bool]] = None,
1412
- **kwargs: Any,
1413
- ) -> Iterator[ChatGenerationChunk]:
1414
- request = self._prepare_request(
1415
- messages,
1416
- stop=stop,
1417
- tools=tools,
1418
- functions=functions,
1419
- safety_settings=safety_settings,
1420
- tool_config=tool_config,
1421
- generation_config=generation_config,
1422
- cached_content=cached_content or self.cached_content,
1423
- tool_choice=tool_choice,
1424
- )
1425
- response: GenerateContentResponse = _chat_with_retry(
1426
- request=request,
1427
- generation_method=self.client.stream_generate_content,
1428
- **kwargs,
1429
- metadata=self.default_metadata,
1430
- )
1431
-
1432
- prev_usage_metadata: UsageMetadata | None = None
1433
- for chunk in response:
1434
- _chat_result = _response_to_result(
1435
- chunk, stream=True, prev_usage=prev_usage_metadata
1436
- )
1437
- gen = cast(ChatGenerationChunk, _chat_result.generations[0])
1438
- message = cast(AIMessageChunk, gen.message)
1439
 
1440
- curr_usage_metadata: UsageMetadata | dict[str, int] = (
1441
- message.usage_metadata or {}
1442
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1443
 
1444
- prev_usage_metadata = (
1445
- message.usage_metadata
1446
- if prev_usage_metadata is None
1447
- else UsageMetadata(
1448
- input_tokens=prev_usage_metadata.get("input_tokens", 0)
1449
- + curr_usage_metadata.get("input_tokens", 0),
1450
- output_tokens=prev_usage_metadata.get("output_tokens", 0)
1451
- + curr_usage_metadata.get("output_tokens", 0),
1452
- total_tokens=prev_usage_metadata.get("total_tokens", 0)
1453
- + curr_usage_metadata.get("total_tokens", 0),
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1454
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
1455
  )
1456
-
1457
- if run_manager:
1458
- run_manager.on_llm_new_token(gen.text)
1459
- yield gen
1460
-
1461
- async def _astream(
1462
- self,
1463
- messages: List[BaseMessage],
1464
- stop: Optional[List[str]] = None,
1465
- run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
1466
- *,
1467
- tools: Optional[Sequence[Union[_ToolDict, GoogleTool]]] = None,
1468
- functions: Optional[Sequence[_FunctionDeclarationType]] = None,
1469
- safety_settings: Optional[SafetySettingDict] = None,
1470
- tool_config: Optional[Union[Dict, _ToolConfigDict]] = None,
1471
- generation_config: Optional[Dict[str, Any]] = None,
1472
- cached_content: Optional[str] = None,
1473
- tool_choice: Optional[Union[_ToolChoiceType, bool]] = None,
1474
- **kwargs: Any,
1475
- ) -> AsyncIterator[ChatGenerationChunk]:
1476
- if not self.async_client:
1477
- updated_kwargs = {
1478
- **kwargs,
1479
- **{
1480
- "tools": tools,
1481
- "functions": functions,
1482
- "safety_settings": safety_settings,
1483
- "tool_config": tool_config,
1484
- "generation_config": generation_config,
1485
- },
1486
- }
1487
- async for value in super()._astream(
1488
- messages, stop, run_manager, **updated_kwargs
1489
- ):
1490
- yield value
1491
  else:
1492
- request = self._prepare_request(
1493
- messages,
1494
- stop=stop,
1495
- tools=tools,
1496
- functions=functions,
1497
- safety_settings=safety_settings,
1498
- tool_config=tool_config,
1499
- generation_config=generation_config,
1500
- cached_content=cached_content or self.cached_content,
1501
- tool_choice=tool_choice,
1502
- )
1503
- prev_usage_metadata: UsageMetadata | None = None
1504
- async for chunk in await _achat_with_retry(
1505
- request=request,
1506
- generation_method=self.async_client.stream_generate_content,
1507
- **kwargs,
1508
- metadata=self.default_metadata,
1509
- ):
1510
- _chat_result = _response_to_result(
1511
- chunk, stream=True, prev_usage=prev_usage_metadata
1512
  )
1513
- gen = cast(ChatGenerationChunk, _chat_result.generations[0])
1514
- message = cast(AIMessageChunk, gen.message)
1515
-
1516
- curr_usage_metadata: UsageMetadata | dict[str, int] = (
1517
- message.usage_metadata or {}
 
 
 
 
 
 
 
 
 
 
 
1518
  )
1519
-
1520
- prev_usage_metadata = (
1521
- message.usage_metadata
1522
- if prev_usage_metadata is None
1523
- else UsageMetadata(
1524
- input_tokens=prev_usage_metadata.get("input_tokens", 0)
1525
- + curr_usage_metadata.get("input_tokens", 0),
1526
- output_tokens=prev_usage_metadata.get("output_tokens", 0)
1527
- + curr_usage_metadata.get("output_tokens", 0),
1528
- total_tokens=prev_usage_metadata.get("total_tokens", 0)
1529
- + curr_usage_metadata.get("total_tokens", 0),
1530
- )
 
 
 
 
 
 
 
 
 
 
 
1531
  )
 
 
 
 
 
 
 
 
 
1532
 
1533
- if run_manager:
1534
- await run_manager.on_llm_new_token(gen.text)
1535
- yield gen
1536
-
1537
- def _prepare_request(
1538
- self,
1539
- messages: List[BaseMessage],
1540
- *,
1541
- stop: Optional[List[str]] = None,
1542
- tools: Optional[Sequence[Union[_ToolDict, GoogleTool]]] = None,
1543
- functions: Optional[Sequence[_FunctionDeclarationType]] = None,
1544
- safety_settings: Optional[SafetySettingDict] = None,
1545
- tool_config: Optional[Union[Dict, _ToolConfigDict]] = None,
1546
- tool_choice: Optional[Union[_ToolChoiceType, bool]] = None,
1547
- generation_config: Optional[Dict[str, Any]] = None,
1548
- cached_content: Optional[str] = None,
1549
- ) -> Tuple[GenerateContentRequest, Dict[str, Any]]:
1550
- if tool_choice and tool_config:
1551
- raise ValueError(
1552
- "Must specify at most one of tool_choice and tool_config, received "
1553
- f"both:\n\n{tool_choice=}\n\n{tool_config=}"
1554
  )
 
 
 
 
 
 
 
 
 
1555
 
1556
- formatted_tools = None
1557
- code_execution_tool = GoogleTool(code_execution=CodeExecution())
1558
- if tools == [code_execution_tool]:
1559
- formatted_tools = tools
1560
- elif tools:
1561
- formatted_tools = [convert_to_genai_function_declarations(tools)]
1562
- elif functions:
1563
- formatted_tools = [convert_to_genai_function_declarations(functions)]
1564
-
1565
- filtered_messages = []
1566
- for message in messages:
1567
- if isinstance(message, HumanMessage) and not message.content:
1568
- warnings.warn(
1569
- "HumanMessage with empty content was removed to prevent API error"
1570
- )
1571
- else:
1572
- filtered_messages.append(message)
1573
- messages = filtered_messages
1574
-
1575
- system_instruction, history = _parse_chat_history(
1576
- messages,
1577
- convert_system_message_to_human=self.convert_system_message_to_human,
1578
- )
1579
- if tool_choice:
1580
- if not formatted_tools:
1581
- msg = (
1582
- f"Received {tool_choice=} but no {tools=}. 'tool_choice' can only "
1583
- f"be specified if 'tools' is specified."
1584
- )
1585
- raise ValueError(msg)
1586
- all_names: List[str] = []
1587
- for t in formatted_tools:
1588
- if hasattr(t, "function_declarations"):
1589
- t_with_declarations = cast(Any, t)
1590
- all_names.extend(
1591
- f.name for f in t_with_declarations.function_declarations
1592
- )
1593
- elif isinstance(t, GoogleTool) and hasattr(t, "code_execution"):
1594
- continue
1595
- else:
1596
- raise TypeError(
1597
- f"Tool {t} doesn't have function_declarations attribute"
1598
- )
1599
-
1600
- tool_config = _tool_choice_to_tool_config(tool_choice, all_names)
1601
-
1602
- formatted_tool_config = None
1603
- if tool_config:
1604
- formatted_tool_config = ToolConfig(
1605
- function_calling_config=tool_config["function_calling_config"]
1606
- )
1607
- formatted_safety_settings = []
1608
- if safety_settings:
1609
- formatted_safety_settings = [
1610
- SafetySetting(category=c, threshold=t)
1611
- for c, t in safety_settings.items()
1612
- ]
1613
- request = GenerateContentRequest(
1614
- model=self.model,
1615
- contents=history,
1616
- tools=formatted_tools,
1617
- tool_config=formatted_tool_config,
1618
- safety_settings=formatted_safety_settings,
1619
- generation_config=self._prepare_params(
1620
- stop, generation_config=generation_config
1621
- ),
1622
- cached_content=cached_content,
1623
- )
1624
- if system_instruction:
1625
- request.system_instruction = system_instruction
1626
 
1627
- return request
 
1628
 
1629
- def get_num_tokens(self, text: str) -> int:
1630
- """Get the number of tokens present in the text.
 
 
1631
 
1632
- Useful for checking if an input will fit in a model's context window.
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1633
 
1634
- Args:
1635
- text: The string input to tokenize.
1636
 
1637
- Returns:
1638
- The integer number of tokens in the text.
1639
- """
1640
- result = self.client.count_tokens(
1641
- model=self.model, contents=[Content(parts=[Part(text=text)])]
1642
- )
1643
- return result.total_tokens
1644
 
1645
- def with_structured_output(
1646
- self,
1647
- schema: Union[Dict, Type[BaseModel]],
1648
- *,
1649
- include_raw: bool = False,
1650
- **kwargs: Any,
1651
- ) -> Runnable[LanguageModelInput, Union[Dict, BaseModel]]:
1652
- _ = kwargs.pop("method", None)
1653
- _ = kwargs.pop("strict", None)
1654
- if kwargs:
1655
- raise ValueError(f"Received unsupported arguments {kwargs}")
1656
- tool_name = _get_tool_name(schema) # type: ignore[arg-type]
1657
- if isinstance(schema, type) and is_basemodel_subclass_safe(schema):
1658
- parser: OutputParserLike = PydanticToolsParser(
1659
- tools=[schema], first_tool_only=True
1660
- )
1661
  else:
1662
- parser = JsonOutputKeyToolsParser(key_name=tool_name, first_tool_only=True)
1663
- tool_choice = tool_name if self._supports_tool_choice else None
 
 
1664
  try:
1665
- llm = self.bind_tools(
1666
- [schema],
1667
- tool_choice=tool_choice,
1668
- ls_structured_output_format={
1669
- "kwargs": {"method": "function_calling"},
1670
- "schema": convert_to_openai_tool(schema),
1671
- },
1672
- )
1673
- except Exception:
1674
- llm = self.bind_tools([schema], tool_choice=tool_choice)
1675
- if include_raw:
1676
- parser_with_fallback = RunnablePassthrough.assign(
1677
- parsed=itemgetter("raw") | parser, parsing_error=lambda _: None
1678
- ).with_fallbacks(
1679
- [RunnablePassthrough.assign(parsed=lambda _: None)],
1680
- exception_key="parsing_error",
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1681
  )
1682
- return {"raw": llm} | parser_with_fallback
1683
- else:
1684
- return llm | parser
1685
-
1686
- def bind_tools(
1687
- self,
1688
- tools: Sequence[
1689
- dict[str, Any] | type | Callable[..., Any] | BaseTool | GoogleTool
1690
- ],
1691
- tool_config: Optional[Union[Dict, _ToolConfigDict]] = None,
1692
- *,
1693
- tool_choice: Optional[Union[_ToolChoiceType, bool]] = None,
1694
- **kwargs: Any,
1695
- ) -> Runnable[LanguageModelInput, BaseMessage]:
1696
- """Bind tool-like objects to this chat model.
1697
-
1698
- Assumes model is compatible with google-generativeAI tool-calling API.
1699
-
1700
- Args:
1701
- tools: A list of tool definitions to bind to this chat model.
1702
- Can be a pydantic model, callable, or BaseTool. Pydantic
1703
- models, callables, and BaseTools will be automatically converted to
1704
- their schema dictionary representation.
1705
- **kwargs: Any additional parameters to pass to the
1706
- :class:`~langchain.runnable.Runnable` constructor.
1707
- """
1708
- if tool_choice and tool_config:
1709
- raise ValueError(
1710
- "Must specify at most one of tool_choice and tool_config, received "
1711
- f"both:\n\n{tool_choice=}\n\n{tool_config=}"
1712
  )
1713
- try:
1714
- formatted_tools: list = [convert_to_openai_tool(tool) for tool in tools] # type: ignore[arg-type]
1715
- except Exception:
1716
- formatted_tools = [
1717
- tool_to_dict(convert_to_genai_function_declarations(tools))
1718
- ]
1719
- if tool_choice:
1720
- kwargs["tool_choice"] = tool_choice
1721
- elif tool_config:
1722
- kwargs["tool_config"] = tool_config
 
 
 
 
 
 
 
 
 
 
 
1723
  else:
1724
- pass
1725
- return self.bind(tools=formatted_tools, **kwargs)
1726
-
1727
- @property
1728
- def _supports_tool_choice(self) -> bool:
1729
- return (
1730
- "gemini-1.5-pro" in self.model
1731
- or "gemini-1.5-flash" in self.model
1732
- or "gemini-2" in self.model
 
 
 
1733
  )
1734
-
1735
-
1736
- def _get_tool_name(
1737
- tool: Union[_ToolDict, GoogleTool, Dict],
1738
- ) -> str:
1739
- try:
1740
- genai_tool = tool_to_dict(convert_to_genai_function_declarations([tool]))
1741
- return [f["name"] for f in genai_tool["function_declarations"]][0] # type: ignore[index]
1742
- except ValueError as e: # other TypedDict
1743
- if is_typeddict(tool):
1744
- return convert_to_openai_tool(cast(Dict, tool))["function"]["name"]
1745
- else:
1746
- raise e
 
1
# System prompt handed to the LLM agent. It describes the three dataframes
# that are pre-loaded into the agent's environment (df, ncap_data,
# states_data) and the reporting conventions the model must follow.
# NOTE(review): the prompt says "follow the previous instructions for saving
# and reporting plots" — those plot instructions are not visible in this
# chunk; confirm they are prepended elsewhere before this prompt is used.
SYSTEM_PROMPT = """
1. Air Quality Data (df):
    - Columns: 'Timestamp', 'station', 'PM2.5', 'PM10', 'address', 'city', 'latitude', 'longitude', 'state'
    - Example row: ['2023-01-01', 'StationA', 45.67, 78.9, '123 Main St', 'Mumbai', 19.07, 72.87, 'Maharashtra']
    - Frequency: daily
    - 'pollution' generally means 'PM2.5'.
    - PM2.5 guidelines: India: 60, WHO: 15. PM10 guidelines: India: 100, WHO: 50.

2. NCAP Funding Data (ncap_data):
    - Columns: 'city', 'state', 'funding_received', 'year', 'project', 'status'
    - Example row: ['Mumbai', 'Maharashtra', 10000000, 2022, 'Clean Air Project', 'Ongoing']

3. State Population Data (states_data):
    - Columns: 'state', 'population', 'year', 'urban_population', 'rural_population'
    - Example row: ['Maharashtra', 123000000, 2021, 60000000, 63000000]

You already have these dataframes loaded as df, ncap_data, and states_data. Do not read any files. Use these dataframes to answer questions about air quality, funding, or population. When aggregating, report standard deviation, standard error, and number of data points. Always report units. If a plot is required, follow the previous instructions for saving and reporting plots. If a question is about funding or population, use the relevant dataframe.
"""
 
20
 
21
# Standard library
import json
import os
import uuid
from datetime import datetime
from typing import Tuple

# Third-party
import matplotlib.pyplot as plt
import pandas as pd
from dotenv import load_dotenv  # FIX: was imported twice
from langchain_google_genai import ChatGoogleGenerativeAI
from langchain_groq import ChatGroq
from pandasai import Agent, SmartDataframe
from pandasai.llm import HuggingFaceTextGen
from PIL import Image

# FORCE reload environment variables so stale shell values are overridden.
load_dotenv(override=True)
Groq_Token = os.getenv("GROQ_API_KEY")      # Groq API key (may be None)
hf_token = os.getenv("HF_TOKEN")            # Hugging Face token used for logging
gemini_token = os.getenv("GEMINI_TOKEN")    # Google Gemini API key (may be None)
44
+ models = {
45
+ "gpt-oss-20b": "openai/gpt-oss-20b",
46
+ "gpt-oss-120b": "openai/gpt-oss-120b",
47
+ "llama3.1": "llama-3.1-8b-instant",
48
+ "llama3.3": "llama-3.3-70b-versatile",
49
+ "deepseek-R1": "deepseek-r1-distill-llama-70b",
50
+ "llama4 maverik":"meta-llama/llama-4-maverick-17b-128e-instruct",
51
+ "llama4 scout":"meta-llama/llama-4-scout-17b-16e-instruct",
52
+ "gemini-pro": "gemini-1.5-pro"
53
+ }
54
+
55
def log_interaction(user_query, model_name, response_content, generated_code, execution_time, error_message=None, is_image=False):
    """Record one user interaction as a parquet row (best-effort).

    Any failure is printed and swallowed so logging can never break the
    main request path.

    Args:
        user_query: Raw question the user asked.
        model_name: Key of the model used (see ``models``).
        response_content: Final response shown to the user (stringified).
        generated_code: Code produced by the LLM, or "".
        execution_time: Wall-clock seconds for the whole interaction.
        error_message: Error text when the interaction failed, else None.
        is_image: True when the response is a path to a generated plot.
    """
    try:
        # Without an HF token there is nowhere to ship logs; skip quietly.
        if not hf_token or hf_token.strip() == "":
            print("Warning: HF_TOKEN not available, skipping logging")
            return

        # One flat record per interaction.
        log_entry = {
            "timestamp": datetime.now().isoformat(),
            "session_id": str(uuid.uuid4()),
            "user_query": user_query,
            "model_name": model_name,
            "response_content": str(response_content),
            "generated_code": generated_code or "",
            "execution_time_seconds": execution_time,
            "error_message": error_message or "",
            "is_image_output": is_image,
            "success": error_message is None
        }

        df = pd.DataFrame([log_entry])

        # Unique filename so concurrent sessions never clobber each other.
        timestamp_str = datetime.now().strftime("%Y%m%d_%H%M%S")
        random_id = str(uuid.uuid4())[:8]
        filename = f"interaction_log_{timestamp_str}_{random_id}.parquet"

        # BUG FIX: the path was previously the literal "/tmp/(unknown)" —
        # the computed filename was never interpolated.
        local_path = f"/tmp/{filename}"
        df.to_parquet(local_path, index=False)
        # NOTE(review): the file is deleted immediately after writing and is
        # never uploaded — the Hugging Face upload step appears to be
        # missing; confirm intended behavior.
        if os.path.exists(local_path):
            os.remove(local_path)
        print(f"Successfully logged interaction locally: {filename}")
    except Exception as e:
        print(f"Error logging interaction: {e}")
93
+
94
def preprocess_and_load_df(path: str) -> pd.DataFrame:
    """Load the air-quality CSV at *path* and parse its Timestamp column.

    Args:
        path: Filesystem path to a CSV with (at least) a 'Timestamp' column.

    Returns:
        The loaded dataframe with 'Timestamp' converted to datetime64.

    Raises:
        Exception: wrapping the underlying error. The broad wrapper type is
            kept for backward compatibility with existing callers.
    """
    try:
        df = pd.read_csv(path)
        df["Timestamp"] = pd.to_datetime(df["Timestamp"])
        return df
    except Exception as e:
        # Chain the cause so the original traceback is not lost.
        raise Exception(f"Error loading dataframe: {e}") from e
102
+
103
def load_smart_df(df: pd.DataFrame, inference_server: str, name="mistral") -> SmartDataframe:
    """Wrap *df* in a pandasai SmartDataframe backed by the selected LLM.

    Args:
        df: Dataframe to expose to the LLM.
        inference_server: Unused; kept for backward compatibility.
            NOTE(review): consider removing once callers are audited.
        name: Key into ``models``; "gemini-pro" selects Gemini, anything
            else selects Groq. NOTE: the historical default "mistral" is not
            in ``models`` — it now fails fast with a clear message instead
            of an opaque KeyError.

    Returns:
        A configured SmartDataframe (caching disabled, up to 5 retries).

    Raises:
        Exception: wrapping any underlying failure (missing token, unknown
            model, client construction error).
    """
    try:
        # Fail fast with an actionable message for unknown model names.
        if name not in models:
            raise ValueError(f"Unknown model name: {name!r}; expected one of {sorted(models)}")
        if name == "gemini-pro":
            if not gemini_token or gemini_token.strip() == "":
                raise ValueError("Gemini API token not available or empty")
            llm = ChatGoogleGenerativeAI(
                model=models[name],
                google_api_key=gemini_token,
                temperature=0.1
            )
        else:
            if not Groq_Token or Groq_Token.strip() == "":
                raise ValueError("Groq API token not available or empty")
            llm = ChatGroq(
                model=models[name],
                api_key=Groq_Token,
                temperature=0.1
            )
        smart_df = SmartDataframe(df, config={"llm": llm, "max_retries": 5, "enable_cache": False})
        return smart_df
    except Exception as e:
        # Chain the cause so the original traceback is not lost.
        raise Exception(f"Error loading smart dataframe: {e}") from e
126
def ask_question_pandas(agent, prompt):
    """Run *prompt* through a pandasai Agent and package the result.

    BUG FIX: this body previously sat at module level with bare ``return``
    statements (a SyntaxError that prevented the module from importing) and
    referenced undefined names ``agent``, ``prompt`` and ``start_time`` —
    apparently the stripped remains of a deleted function. It is re-wrapped
    here with those names supplied properly; the body itself is unchanged.

    Args:
        agent: A pandasai ``Agent``-like object exposing ``chat`` and,
            optionally, ``last_code_generated`` / ``last_code_executed`` /
            ``last_prompt`` attributes.
        prompt: The user's question.

    Returns:
        dict with keys "role", "content", "gen_code", "ex_code",
        "last_prompt" and "error" (None on success).
    """
    start_time = datetime.now()
    try:
        response = agent.chat(prompt)
        execution_time = (datetime.now() - start_time).total_seconds()

        gen_code = getattr(agent, 'last_code_generated', '')
        ex_code = getattr(agent, 'last_code_executed', '')
        last_prompt = getattr(agent, 'last_prompt', prompt)

        # Log the interaction (best-effort; never raises).
        log_interaction(
            user_query=prompt,
            model_name="pandas_ai_agent",
            response_content=response,
            generated_code=gen_code,
            execution_time=execution_time,
            error_message=None,
            is_image=isinstance(response, str) and any(response.endswith(ext) for ext in ['.png', '.jpg', '.jpeg'])
        )

        return {
            "role": "assistant",
            "content": response,
            "gen_code": gen_code,
            "ex_code": ex_code,
            "last_prompt": last_prompt,
            "error": None
        }
    except Exception as e:
        execution_time = (datetime.now() - start_time).total_seconds()
        error_msg = str(e)

        # Log the failed interaction
        log_interaction(
            user_query=prompt,
            model_name="pandas_ai_agent",
            response_content=f"Error: {error_msg}",
            generated_code="",
            execution_time=execution_time,
            error_message=error_msg,
            is_image=False
        )

        return {
            "role": "assistant",
            "content": f"Error: {error_msg}",
            "gen_code": "",
            "ex_code": "",
            "last_prompt": prompt,
            "error": error_msg
        }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
176
 
177
def decorate_with_code(response: dict) -> str:
    """Render the generated code and originating prompt of *response* as
    collapsible HTML/markdown ``<details>`` sections.

    Missing fields fall back to placeholder text. Note the trailing
    ``</details>`` for the prompt section is intentionally absent — the
    caller appends it.
    """
    code_text = response.get("gen_code", "No code generated")
    prompt_text = response.get("last_prompt", "No prompt")

    decorated = f"""<details>
<summary>Generated Code</summary>

```python
{code_text}
```
</details>

<details>
<summary>Prompt</summary>

{prompt_text}
"""
    return decorated
195
 
196
def show_response(st, response):
    """Render an assistant *response* dict inside a Streamlit chat message.

    First tries to interpret ``response["content"]`` as an image path; on
    any failure it falls back to rendering the content as markdown text,
    prefixing the collapsible generated-code block when present.

    Args:
        st: The streamlit module (or a compatible object).
        response: dict with at least "role" and "content".

    Returns:
        dict: ``{"is_image": bool}`` describing how the content was shown.
    """
    try:
        with st.chat_message(response["role"]):
            content = response.get("content", "No content")

            try:
                # Try to open as image
                image = Image.open(content)
                if response.get("gen_code"):
                    st.markdown(decorate_with_code(response), unsafe_allow_html=True)
                st.image(image)
                return {"is_image": True}
            # BUG FIX: was a bare ``except:`` which also swallowed
            # KeyboardInterrupt/SystemExit; narrowed to Exception.
            except Exception:
                # Not an image, display as text
                if response.get("gen_code"):
                    # decorate_with_code leaves its last <details> open.
                    display_content = decorate_with_code(response) + f"""</details>

{content}"""
                else:
                    display_content = content
                st.markdown(display_content, unsafe_allow_html=True)
                return {"is_image": False}
    except Exception as e:
        st.error(f"Error displaying response: {e}")
        return {"is_image": False}
222
+
223
def ask_question(model_name, question):
    """Answer *question* by having an LLM write pandas code, executing it,
    and returning whatever the code stores in its ``answer`` variable.

    Flow: reload env tokens -> build/validate the LLM client for
    *model_name* -> require ``Data.csv`` to exist -> prompt the LLM with a
    code skeleton -> ``exec`` the returned code -> package the result.
    Every outcome, success or failure, is logged via ``log_interaction``.

    Args:
        model_name: Key into ``models``; "gemini-pro" selects Gemini,
            anything else a Groq-hosted model.
        question: The user's natural-language question.

    Returns:
        dict with keys "role", "content", "gen_code", "ex_code",
        "last_prompt" and "error" (None on success).
    """
    start_time = datetime.now()
    try:
        # Reload environment variables to get fresh values (tokens may have
        # been rotated since module import).
        load_dotenv(override=True)
        fresh_groq_token = os.getenv("GROQ_API_KEY")
        fresh_gemini_token = os.getenv("GEMINI_TOKEN")

        print(f"ask_question - Fresh Groq Token: {'Present' if fresh_groq_token else 'Missing'}")

        # Check API availability with fresh tokens
        if model_name == "gemini-pro":
            if not fresh_gemini_token or fresh_gemini_token.strip() == "":
                execution_time = (datetime.now() - start_time).total_seconds()
                error_msg = "Missing or empty API token"

                # Log the failed interaction
                log_interaction(
                    user_query=question,
                    model_name=model_name,
                    response_content="❌ Gemini API token not available or empty",
                    generated_code="",
                    execution_time=execution_time,
                    error_message=error_msg,
                    is_image=False
                )

                return {
                    "role": "assistant",
                    "content": "❌ Gemini API token not available or empty. Please set GEMINI_TOKEN in your environment variables.",
                    "gen_code": "",
                    "ex_code": "",
                    "last_prompt": question,
                    "error": error_msg
                }
            llm = ChatGoogleGenerativeAI(
                model=models[model_name],
                google_api_key=fresh_gemini_token,
                temperature=0
            )
        else:
            if not fresh_groq_token or fresh_groq_token.strip() == "":
                execution_time = (datetime.now() - start_time).total_seconds()
                error_msg = "Missing or empty API token"

                # Log the failed interaction
                log_interaction(
                    user_query=question,
                    model_name=model_name,
                    response_content="❌ Groq API token not available or empty",
                    generated_code="",
                    execution_time=execution_time,
                    error_message=error_msg,
                    is_image=False
                )

                return {
                    "role": "assistant",
                    "content": "❌ Groq API token not available or empty. Please set GROQ_API_KEY in your environment variables and restart the application.",
                    "gen_code": "",
                    "ex_code": "",
                    "last_prompt": question,
                    "error": error_msg
                }

            # Test the API key by trying to create the client
            try:
                llm = ChatGroq(
                    model=models[model_name],
                    api_key=fresh_groq_token,
                    temperature=0.1
                )
                # Test with a simple call to verify the API key works
                # NOTE(review): this spends one API call on every question;
                # consider caching the validation result.
                test_response = llm.invoke("Test")
                print("API key test successful")
            except Exception as api_error:
                execution_time = (datetime.now() - start_time).total_seconds()
                error_msg = str(api_error)

                # Distinguish credential problems from transport problems.
                if "organization_restricted" in error_msg.lower() or "unauthorized" in error_msg.lower():
                    response_content = "❌ API Key Error: Your Groq API key appears to be invalid, expired, or restricted. Please check your API key in the .env file."
                    log_error_msg = f"API key validation failed: {error_msg}"
                else:
                    response_content = f"❌ API Connection Error: {error_msg}"
                    log_error_msg = error_msg

                # Log the failed interaction
                log_interaction(
                    user_query=question,
                    model_name=model_name,
                    response_content=response_content,
                    generated_code="",
                    execution_time=execution_time,
                    error_message=log_error_msg,
                    is_image=False
                )

                return {
                    "role": "assistant",
                    "content": response_content,
                    "gen_code": "",
                    "ex_code": "",
                    "last_prompt": question,
                    "error": log_error_msg
                }

        # Check if data file exists (relative to the current working dir).
        if not os.path.exists("Data.csv"):
            execution_time = (datetime.now() - start_time).total_seconds()
            error_msg = "Data file not found"

            # Log the failed interaction
            log_interaction(
                user_query=question,
                model_name=model_name,
                response_content="❌ Data.csv file not found",
                generated_code="",
                execution_time=execution_time,
                error_message=error_msg,
                is_image=False
            )

            return {
                "role": "assistant",
                "content": "❌ Data.csv file not found. Please ensure the data file is in the correct location.",
                "gen_code": "",
                "ex_code": "",
                "last_prompt": question,
                "error": error_msg
            }

        # Only a small sample is needed — the prompt embeds just the dtypes.
        df_check = pd.read_csv("Data.csv")
        df_check["Timestamp"] = pd.to_datetime(df_check["Timestamp"])
        df_check = df_check.head(5)

        new_line = "\n"
        parameters = {"font.size": 12, "figure.dpi": 600}

        # Code skeleton handed to the LLM. Single braces interpolate now:
        # {parameters} becomes the dict literal above and the dtypes of
        # Data.csv are embedded as comment lines. The interior is flush-left
        # because this exact text is later exec()'d as module-level code.
        template = f"""```python
import pandas as pd
import matplotlib.pyplot as plt
import uuid

plt.rcParams.update({parameters})

df = pd.read_csv("Data.csv")
df["Timestamp"] = pd.to_datetime(df["Timestamp"])

# Available columns and data types:
{new_line.join(map(lambda x: '# '+x, str(df_check.dtypes).split(new_line)))}

# Question: {question.strip()}
# Generate code to answer the question and save result in 'answer' variable
# If creating a plot, save it with a unique filename and store the filename in 'answer'
# If returning text/numbers, store the result directly in 'answer'
```"""

        system_prompt = """You are a helpful assistant that generates Python code for data analysis.

Rules:
1. Always save your final result in a variable called 'answer'
2. If creating a plot, save it with plt.savefig() and store the filename in 'answer'
3. If returning text/numbers, store the result directly in 'answer'
4. Use descriptive variable names and add comments
5. Handle potential errors gracefully
6. For plots, use unique filenames to avoid conflicts
"""

        query = f"""{system_prompt}

Complete the following code to answer the user's question:

{template}
"""

        # Make API call
        # NOTE(review): both branches are identical; kept as-is to preserve
        # behavior, but they could be collapsed.
        if model_name == "gemini-pro":
            response = llm.invoke(query)
            answer = response.content
        else:
            response = llm.invoke(query)
            answer = response.content

        # Extract and execute code
        try:
            # Pull the fenced code out of the reply, if the LLM fenced it.
            if "```python" in answer:
                code_part = answer.split("```python")[1].split("```")[0]
            else:
                code_part = answer

            # Prepend the fixed skeleton so imports/df are always defined
            # regardless of what the LLM returned.
            full_code = f"""
{template.split("```python")[1].split("```")[0]}
{code_part}
"""

            # Execute code in a controlled environment
            # SECURITY NOTE: exec() runs LLM-generated code with no
            # sandboxing — anything this process can do, the generated code
            # can do. Acceptable only in a trusted deployment.
            local_vars = {}
            global_vars = {
                'pd': pd,
                'plt': plt,
                'os': os,
                'uuid': __import__('uuid')
            }

            exec(full_code, global_vars, local_vars)

            # Get the answer (the contract asks the LLM to set 'answer').
            if 'answer' in local_vars:
                answer_result = local_vars['answer']
            else:
                answer_result = "No answer variable found in generated code"

            execution_time = (datetime.now() - start_time).total_seconds()

            # Determine if output is an image (a saved-plot file path)
            is_image = isinstance(answer_result, str) and any(answer_result.endswith(ext) for ext in ['.png', '.jpg', '.jpeg'])

            # Log successful interaction
            log_interaction(
                user_query=question,
                model_name=model_name,
                response_content=str(answer_result),
                generated_code=full_code,
                execution_time=execution_time,
                error_message=None,
                is_image=is_image
            )

            return {
                "role": "assistant",
                "content": answer_result,
                "gen_code": full_code,
                "ex_code": full_code,
                "last_prompt": question,
                "error": None
            }

        except Exception as code_error:
            execution_time = (datetime.now() - start_time).total_seconds()
            error_msg = str(code_error)

            # Log the failed code execution ('full_code' may not exist if
            # the failure happened while assembling it).
            log_interaction(
                user_query=question,
                model_name=model_name,
                response_content=f"❌ Error executing generated code: {error_msg}",
                generated_code=full_code if 'full_code' in locals() else "",
                execution_time=execution_time,
                error_message=error_msg,
                is_image=False
            )

            return {
                "role": "assistant",
                "content": f"❌ Error executing generated code: {error_msg}",
                "gen_code": full_code if 'full_code' in locals() else "",
                "ex_code": full_code if 'full_code' in locals() else "",
                "last_prompt": question,
                "error": error_msg
            }

    except Exception as e:
        execution_time = (datetime.now() - start_time).total_seconds()
        error_msg = str(e)

        # Handle specific API errors with friendlier user-facing messages.
        if "organization_restricted" in error_msg:
            response_content = "❌ API Organization Restricted: Your API key access has been restricted. Please check your Groq API key or try generating a new one."
            log_error_msg = "API access restricted"
        elif "rate_limit" in error_msg.lower():
            response_content = "❌ Rate limit exceeded. Please wait a moment and try again."
            log_error_msg = "Rate limit exceeded"
        else:
            response_content = f"❌ Error: {error_msg}"
            log_error_msg = error_msg

        # Log the failed interaction
        log_interaction(
            user_query=question,
            model_name=model_name,
            response_content=response_content,
            generated_code="",
            execution_time=execution_time,
            error_message=log_error_msg,
            is_image=False
        )

        return {
            "role": "assistant",
            "content": response_content,
            "gen_code": "",
            "ex_code": "",
            "last_prompt": question,
            "error": log_error_msg
        }