|
|
from typing import Optional |
|
|
|
|
|
import langchain |
|
|
from dotenv import load_dotenv |
|
|
from langchain_core.prompts import PromptTemplate |
|
|
from langchain import chains |
|
|
from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler |
|
|
from pydantic import ValidationError |
|
|
from langchain.agents import AgentExecutor |
|
|
from langchain.agents.mrkl.base import ZeroShotAgent |
|
|
from prompts import FORMAT_INSTRUCTIONS, QUESTION_PROMPT, QUESTION_PROMPT1, SUFFIX |
|
|
from tools import make_tools , drug_tools |
|
|
import os |
|
|
from rmrkl import ChatZeroShotAgent, RetryAgentExecutor |
|
|
from langchain_ollama import OllamaLLM |
|
|
import base64 |
|
|
from io import BytesIO |
|
|
from PIL import Image |
|
|
|
|
|
from langchain_openai import ChatOpenAI , OpenAI |
|
|
from langchain.agents import load_tools, initialize_agent, AgentType |
|
|
from langchain.llms import OpenAI |
|
|
|
|
|
|
|
|
def convert_to_base64(pil_image, image_format: str = "PNG"):
    """Serialize an image to a base64-encoded text string.

    Args:
        pil_image: Object exposing ``save(fp, format=...)``; the imports of
            this module suggest a ``PIL.Image.Image`` is expected.
        image_format: Format name handed to ``pil_image.save``. Defaults to
            ``"PNG"``, which is what this function always used before the
            parameter existed.

    Returns:
        The bytes written by ``save`` encoded as base64 and decoded to a
        UTF-8 ``str``.
    """
    # The in-memory buffer is only needed while encoding, so release it
    # deterministically instead of waiting for garbage collection.
    with BytesIO() as buffer:
        pil_image.save(buffer, format=image_format)
        return base64.b64encode(buffer.getvalue()).decode("utf-8")
|
|
|
|
|
def _make_llm(model, temp, api_key, streaming: bool = False):
    """Build the LLM client that matches a model-name prefix.

    Dispatch is purely on how ``model`` starts:

    * ``"claude"``            -> ``OpenAI`` with ``max_tokens=5000`` and
      ``base_url`` read from the ``OPENAI_API_BASE`` environment variable
      (presumably an OpenAI-compatible gateway -- confirm with deployment).
    * ``"gpt"`` / ``"deepseek"`` -> ``ChatOpenAI`` with a 1000 s timeout;
      ``base_url`` is passed only when ``OPENAI_API_BASE`` is set.
    * ``"llama"``             -> ``OllamaLLM``.

    Args:
        model: Model identifier; its prefix selects the backend.
        temp: Sampling temperature. It is applied to every backend
            (previously the chat and Ollama branches ignored it and always
            used ``0.1``).
        api_key: API key forwarded as ``openai_api_key`` to the OpenAI-style
            clients. Not used for ``llama`` models.
        streaming: Accepted for call-site compatibility; it is not forwarded
            to any backend.

    Returns:
        The constructed LLM client.

    Raises:
        ValueError: If ``model`` matches none of the known prefixes.
    """
    if model.startswith("claude"):
        return OpenAI(
            temperature=temp,
            model_name=model,
            max_tokens=5000,
            openai_api_key=api_key,
            base_url=os.getenv("OPENAI_API_BASE"),
        )

    if model.startswith(("gpt", "deepseek")):
        chat_kwargs = {
            "model": model,
            "temperature": temp,
            "timeout": 1000,
            "openai_api_key": api_key,
        }
        # Only override the endpoint when one is configured, so the client's
        # own default is kept otherwise.
        api_base = os.getenv("OPENAI_API_BASE")
        if api_base:
            chat_kwargs["base_url"] = api_base
        return ChatOpenAI(**chat_kwargs)

    if model.startswith("llama"):
        return OllamaLLM(model=model, temperature=temp)

    raise ValueError(f"Invalid model name: {model}")
|
|
|
|
|
|
|
|
class TeLLAgent:
    """Two-stage tool-using agent.

    A first executor (backed by ``model1``) is limited to a single iteration
    and produces a preliminary piece of reasoning. That text is appended to
    the user's prompt and handed to a second executor (backed by ``model2``)
    which may run up to ``max_iterations`` steps and yields the final answer.
    """

    def __init__(
        self,
        tools=None,
        model1: str = "deepseek-ai/DeepSeek-R1",
        model2: str = "gpt-4o-2024-11-20",
        tools_model="gpt-4o-2024-11-20",
        temp=0.1,
        max_iterations=50,
        verbose=True,
        streaming: bool = True,
        openai_api_key=None,
        api_keys: Optional[dict] = None,
        file_path: str = r"...",
        image_path: str = r"...",
    ):
        """Initialize agent.

        Args:
            tools: ``None`` to build the default tool set with ``make_tools``,
                the string ``"drug"`` to build it with ``drug_tools``, or an
                already-built tool collection that is used as-is.
            model1: Model name for the first (single-iteration) stage.
            model2: Model name for the second stage.
            tools_model: Model name for the LLM handed to the tool factory.
            temp: Temperature passed to ``_make_llm`` for every model.
            max_iterations: Iteration cap of the second-stage executor.
            verbose: Forwarded to the tool factory. Both executors are always
                created with ``verbose=True``.
            streaming: Forwarded to ``_make_llm``.
            openai_api_key: API key passed to ``_make_llm``.
            api_keys: Optional mapping of API keys for the tool factory. It is
                copied, never mutated; the copy gets ``OPENAI_API_KEY`` from
                the environment.
            file_path: Stored on the instance, appended to every prompt in
                ``run`` and passed to the tool factory.
            image_path: Same handling as ``file_path``.

        Raises:
            ValueError: If building either stage LLM raises a pydantic
                ``ValidationError``.
        """
        self.file_path = file_path
        self.image_path = image_path
        load_dotenv()

        try:
            self.llm1 = _make_llm(model1, temp, openai_api_key, streaming)
            self.llm2 = _make_llm(model2, temp, openai_api_key, streaming)
        except ValidationError as err:
            # Chain the cause so the underlying validation details survive.
            raise ValueError("Invalid OpenAI API key") from err

        if tools is None or tools == "drug":
            tool_factory = drug_tools if tools == "drug" else make_tools
            # Work on a private copy: the old ``{}`` default was shared
            # between instances and the caller's dict was modified in place.
            tool_keys = dict(api_keys) if api_keys else {}
            tool_keys["OPENAI_API_KEY"] = os.getenv("OPENAI_API_KEY")
            tools_llm = _make_llm(tools_model, temp, openai_api_key, streaming)
            tools = tool_factory(
                tools_llm,
                api_keys=tool_keys,
                verbose=verbose,
                image_path=image_path,
                file_path=file_path,
            )

        # Stage 1: one iteration only; intermediate steps are kept because
        # ``run`` falls back to the first step's log.
        self.agent_executor1 = RetryAgentExecutor.from_agent_and_tools(
            tools=tools,
            agent=ChatZeroShotAgent.from_llm_and_tools(
                self.llm1,
                tools,
                suffix=SUFFIX,
                format_instructions=FORMAT_INSTRUCTIONS,
                question_prompt=QUESTION_PROMPT1,
                return_intermediate_steps=True,
                handle_parsing_errors=True,
            ),
            verbose=True,
            max_iterations=1,
            return_intermediate_steps=True,
            handle_parsing_errors=True,
        )

        # Stage 2: full run, bounded by ``max_iterations``.
        self.agent_executor2 = RetryAgentExecutor.from_agent_and_tools(
            tools=tools,
            agent=ChatZeroShotAgent.from_llm_and_tools(
                self.llm2,
                tools,
                suffix=SUFFIX,
                format_instructions=FORMAT_INSTRUCTIONS,
                question_prompt=QUESTION_PROMPT,
                handle_parsing_errors=True,
            ),
            verbose=True,
            max_iterations=max_iterations,
            handle_parsing_errors=True,
        )

    def _invoke_second_stage(self, question, draft):
        """Run the second executor on ``question`` plus the cleaned ``draft``."""
        second_prompt = " " + question + " " + draft.replace("*", "")
        return self.agent_executor2.invoke({"input": second_prompt})

    def run(self, prompt):
        """Answer ``prompt`` using both stages and return the final output.

        The stored file and image paths are appended to ``prompt`` before the
        first stage runs. The first stage's ``output`` (cut at the first
        ``"Action"`` / ``"Final Answer"`` marker, asterisks removed) is then
        appended to its ``input`` and sent to the second stage.

        If that path fails for any reason, the second stage is retried once
        with the log of the first intermediate step instead. When there is no
        intermediate step to fall back on, the original error is re-raised.

        Args:
            prompt: The user's question as a string.

        Returns:
            The ``"output"`` entry of the second stage's result.
        """
        prompt = prompt + " " + str(self.file_path) + " " + str(self.image_path)
        first = self.agent_executor1.invoke({"input": prompt})

        try:
            draft = first["output"].split("Action")[0].split("Final Answer")[0]
            final = self._invoke_second_stage(first["input"], draft)
        except Exception:
            # ``Exception`` rather than a bare ``except`` so Ctrl-C and
            # interpreter shutdown are not turned into a silent retry.
            steps = first.get("intermediate_steps")
            if not steps:
                # Nothing to fall back on: surface the real failure instead
                # of masking it with an IndexError/KeyError.
                raise
            draft = steps[0][0].log.split("Action")[0]
            final = self._invoke_second_stage(first["input"], draft)

        return final["output"]