# TeLLAgent / tool/rag.py
# (origin metadata: jinysun, "Update tool/rag.py", commit 5c918d5 verified)
# -*- coding: utf-8 -*-
"""
Created on Sun Feb 2 20:31:22 2025
@author: BM109X32G-10GPU-02
"""
from langchain.tools import BaseTool
import os
from langchain.prompts.chat import (
ChatPromptTemplate,
HumanMessagePromptTemplate,
SystemMessagePromptTemplate,
)
from langchain import PromptTemplate
from langchain import HuggingFacePipeline
from langchain.base_language import BaseLanguageModel
from langchain.chains import RetrievalQA
from langchain_community.document_loaders import PyPDFLoader
from langchain_openai import ChatOpenAI
from langchain_community.vectorstores import FAISS
from torch import cuda, bfloat16
device = f'cuda:{cuda.current_device()}' if cuda.is_available() else 'cpu'
from langchain_openai import OpenAIEmbeddings
class rag(BaseTool):
    """Retrieval-augmented-generation tool for a langchain agent.

    Loads a local FAISS vector index, retrieves the top-k passages for a
    query, and asks an OpenAI chat model to answer using that context.
    """

    name: str = "RAG"
    description: str = (
        "Useful to answer questions that require technical "
        "Provide specialized knowledge information for solving Q&A questions"
        "Input query , return the response"
    )
    llm: BaseLanguageModel = None  # chat model used to generate the answer
    path: str = None               # directory of the FAISS index (defaults to "rag")

    def __init__(self, path: str = None):
        """Create the tool.

        Args:
            path: directory containing the FAISS index saved with
                ``FAISS.save_local``. Falls back to ``"rag"`` when omitted.
        """
        super().__init__()
        # Credentials come from the environment so no secrets live in code.
        self.llm = ChatOpenAI(
            model="gpt-4o-2024-11-20",
            api_key=os.getenv("OPENAI_API_KEY"),
            base_url=os.getenv("OPENAI_API_BASE"),
        )
        self.path = path

    def _run(self, query) -> str:
        """Answer ``query`` via retrieval-augmented generation.

        Args:
            query: the user question (plain string).

        Returns:
            The model's answer text (``result['result']`` of the chain).
        """
        embeddings = OpenAIEmbeddings(
            api_key=os.getenv("OPENAI_API_KEY"),
            base_url=os.getenv("OPENAI_API_BASE"),
        )
        # FIX: honor the `path` given at construction time; keep the historical
        # hard-coded "rag" directory as the backward-compatible default.
        # SECURITY NOTE: allow_dangerous_deserialization loads a pickled index —
        # only ever point this at an index you created yourself.
        vectorstore = FAISS.load_local(
            self.path or r"rag", embeddings, allow_dangerous_deserialization=True
        )
        template = """
    You are an expert chemist and your task is to respond to the question or
    solve the problem to the best of your ability.You can only respond with a single "Final Answer" format.
    You need to list the key points and explain them in detail and accurately
    Use the following pieces of context to answer the question at the end.
    If you don't know the answer, just say that you don't know, don't try to make up an answer.
    <context>
    {context}
    </context>
    Question: {question}
    Answer:
    """
        # FIX: the template uses both {context} and {question}; both must be
        # declared or PromptTemplate validation rejects the undeclared variable.
        prompt = PromptTemplate(template=template, input_variables=["context", "question"])
        qa_chain = RetrievalQA.from_chain_type(
            llm=self.llm,
            chain_type="stuff",  # stuff all retrieved docs into one prompt
            retriever=vectorstore.as_retriever(search_kwargs={"k": 5}),
            return_source_documents=False,
            chain_type_kwargs={"prompt": prompt},
        )
        result = qa_chain.invoke(query)
        return result['result']

    async def _arun(self, query) -> str:
        """Use the tool asynchronously."""
        raise NotImplementedError("this tool does not support async")