| from langchain.chat_models import init_chat_model |
| from langchain_core.messages import HumanMessage |
| from dotenv import load_dotenv |
| from typing import List |
| from langchain.tools import BaseTool |
| from langchain.agents import initialize_agent, AgentType |
|
|
# Load variables from a local .env file (python-dotenv) when this module is
# imported; the returned bool is intentionally discarded via `_`.
# NOTE(review): presumably this supplies the provider credentials that
# init_chat_model needs (e.g. a Google API key) — confirm.
_ = load_dotenv()
|
|
class LLM:
    """Convenience wrapper around a LangChain chat model.

    The model is created with ``init_chat_model`` and exposed as
    ``self.chat_model``. The methods cover single-prompt text generation,
    wrapping the model in a tool-using agent, and adjusting sampling
    parameters after construction.
    """

    def __init__(
        self,
        model: str = "gemini-2.0-flash",
        model_provider: str = "google_genai",
        temperature: float = 0.0,
        max_tokens: int = 1000,
    ):
        """Create the underlying chat model.

        Args:
            model: Model name passed to ``init_chat_model``.
            model_provider: LangChain provider identifier for ``model``.
            temperature: Sampling temperature.
            max_tokens: Maximum number of tokens to generate per reply.
        """
        self.chat_model = init_chat_model(
            model=model,
            model_provider=model_provider,
            temperature=temperature,
            max_tokens=max_tokens,
        )

    def generate(self, prompt: str) -> str:
        """Send ``prompt`` as a single user message and return the reply text.

        Args:
            prompt: Text of the user message.

        Returns:
            The reply as a string. When the provider returns structured
            content (a list of content blocks) rather than a plain string,
            the text blocks are concatenated so the declared ``str`` return
            type holds.
        """
        response = self.chat_model.invoke([HumanMessage(content=prompt)])
        return self._content_to_text(response.content)

    @staticmethod
    def _content_to_text(content) -> str:
        """Flatten a message ``content`` value (str or list of blocks) to text."""
        if isinstance(content, str):
            return content
        # Message content may be a list mixing bare strings and block dicts
        # such as {"type": "text", "text": "..."}; keep only the text parts and
        # drop non-text blocks (images, tool-use blocks, ...).
        pieces = []
        for block in content:
            if isinstance(block, str):
                pieces.append(block)
            elif isinstance(block, dict) and block.get("type") == "text":
                pieces.append(block.get("text") or "")
        return "".join(pieces)

    def bind_tools(
        self,
        tools: List[BaseTool],
        agent_type: AgentType = AgentType.ZERO_SHOT_REACT_DESCRIPTION,
        *,
        verbose: bool = False,
    ):
        """Bind LangChain tools to this model and return an AgentExecutor.

        Despite the name, this does not call the chat model's own
        ``bind_tools``; it wraps ``self.chat_model`` in an agent built by
        ``initialize_agent``. The executor holds a reference to the same
        ``self.chat_model`` object.

        Args:
            tools: Tools the agent is allowed to call.
            agent_type: Agent strategy passed to ``initialize_agent``.
            verbose: Whether the executor logs intermediate steps. Defaults
                to ``False``, the previously hard-coded value.

        Returns:
            The executor returned by ``initialize_agent``.
        """
        # NOTE(review): initialize_agent is a legacy LangChain API (deprecated
        # in recent releases); consider migrating to a newer agent constructor.
        return initialize_agent(
            tools,
            self.chat_model,
            agent=agent_type,
            verbose=verbose,
        )

    def set_temperature(self, temperature: float):
        """Set the sampling temperature on the existing chat model in place.

        The model object is mutated, not replaced, so executors previously
        returned by ``bind_tools`` still reference the same object.
        """
        self._set_model_param("temperature", temperature)

    def set_max_tokens(self, max_tokens: int):
        """Set the maximum number of generated tokens on the chat model in place.

        Works across providers whose field for this limit is not literally
        named ``max_tokens`` but accepts it as an alias.
        """
        self._set_model_param("max_tokens", max_tokens)

    def _set_model_param(self, param: str, value) -> None:
        """Assign ``value`` to the chat-model field that ``param`` refers to.

        Provider classes disagree on attribute names. For example,
        langchain-google-genai's chat model declares the token limit as
        ``max_output_tokens`` with ``max_tokens`` only as its alias, and a
        pydantic model rejects ``setattr`` with a name that is not a declared
        field. So look ``param`` up among the declared fields, first by exact
        field name and then by alias, and assign via the real field name.
        """
        model_cls = type(self.chat_model)
        # pydantic v2 models expose ``model_fields``; v1 models ``__fields__``.
        fields = (
            getattr(model_cls, "model_fields", None)
            or getattr(model_cls, "__fields__", None)
            or {}
        )
        if param in fields:
            setattr(self.chat_model, param, value)
            return
        for field_name, field_info in fields.items():
            if getattr(field_info, "alias", None) == param:
                setattr(self.chat_model, field_name, value)
                return
        # Not a declared field or alias: fall back to the original plain
        # assignment (which may raise exactly as before).
        setattr(self.chat_model, param, value)
|
|
|
|
| |
|
|