|
|
import torch |
|
|
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline |
|
|
from langchain.docstore.document import Document as LangchainDocument |
|
|
from langchain.vectorstores import FAISS |
|
|
from langchain_community.embeddings import HuggingFaceEmbeddings |
|
|
from langchain_community.vectorstores.utils import DistanceStrategy |
|
|
import pandas as pd |
|
|
from tqdm import tqdm |
|
|
from langchain.text_splitter import RecursiveCharacterTextSplitter |
|
|
import matplotlib.pyplot as plt |
|
|
|
|
|
# Hugging Face model identifiers used by this script: EMBEDDING_MODEL_NAME is
# loaded through HuggingFaceEmbeddings for retrieval, MODEL_NAME through
# AutoModelForCausalLM for answer generation.
EMBEDDING_MODEL_NAME = "thenlper/gte-small"
MODEL_NAME = "microsoft/Phi-3-mini-128k-instruct"

# Never truncate cell contents when pandas objects are printed.
pd.options.display.max_colwidth = None
|
|
|
|
|
|
|
|
# Load the three raw text sources that together form the knowledge base.
# The encoding is given explicitly so the result does not depend on the
# platform's default locale encoding (e.g. cp1252 on Windows).
# NOTE(review): assumes the source files are UTF-8 encoded — confirm.
with open("iplteams_info.txt", "r", encoding="utf-8") as fp1:
    content1 = fp1.read()

with open("match_summaries_sentences.txt", "r", encoding="utf-8") as fp2:
    content2 = fp2.read()

with open("formatted_playersinfo.txt", "r", encoding="utf-8") as fp3:
    content3 = fp3.read()
|
|
|
|
|
|
|
|
# Glue the three sources together with a triple newline, then cut the combined
# text back apart on that same delimiter; each element of `s` is one section.
_SECTION_DELIMITER = "\n\n\n"
combined_content = _SECTION_DELIMITER.join([content1, content2, content3])
s = combined_content.split(_SECTION_DELIMITER)
|
|
|
|
|
|
|
|
# One LangChain Document per section; tqdm reports progress while wrapping.
RAW_KNOWLEDGE_BASE = [LangchainDocument(page_content=section) for section in tqdm(s)]
|
|
|
|
|
|
|
|
# Separators tried in order by the recursive splitter, from coarse (Markdown
# headings, code fences, horizontal rules) down to single characters.
# Several entries are regular expressions ("#{1,6}", "---+", ...), so the
# splitter is told to interpret them as regexes. Without is_separator_regex it
# escapes every separator and matches it literally, and those patterns would
# never match.
MARKDOWN_SEPARATORS = [
    "\n#{1,6}",
    "```\n",
    "\n\\*\\*\\*+\n",
    "\n---+\n",
    "\n__+\n",
    "\n\n",
    "\n",
    " ",
    "",
]

text_splitter = RecursiveCharacterTextSplitter(
    chunk_size=1000,  # measured with the splitter's length function (len by default, i.e. characters)
    chunk_overlap=100,
    add_start_index=True,  # record each chunk's offset within its source document in metadata
    strip_whitespace=True,
    separators=MARKDOWN_SEPARATORS,
    is_separator_regex=True,
)

# split_documents splits each input document independently, so one call over
# the whole list yields the same chunks as splitting the documents one by one.
docs_processed = text_splitter.split_documents(RAW_KNOWLEDGE_BASE)
|
|
|
|
|
|
|
|
# Inspect how long the chunks are in tokens of the *embedding* model (e.g. to
# compare against that model's maximum input length). encode() adds special
# tokens by default, so the counts include them. A dedicated name keeps this
# tokenizer from being confused with the LLM tokenizer loaded further down.
embedding_tokenizer = AutoTokenizer.from_pretrained(EMBEDDING_MODEL_NAME)
lengths = [len(embedding_tokenizer.encode(doc.page_content)) for doc in tqdm(docs_processed)]

# Series.hist() returns a matplotlib Axes. The title is set exactly once;
# a second plt.title() call would overwrite it on the same Axes.
ax = pd.Series(lengths).hist()
ax.set_title("Histogram of Document Lengths")
ax.set_xlabel("Tokens per chunk")
ax.set_ylabel("Number of chunks")
plt.show()
|
|
|
|
|
|
|
|
# Drop chunks whose text exactly repeats an earlier chunk. The first occurrence
# is kept and the original order is preserved. A set gives O(1) membership
# tests for the texts seen so far.
unique_texts = set()
docs_processed_unique = []
for doc in docs_processed:
    if doc.page_content not in unique_texts:
        unique_texts.add(doc.page_content)
        docs_processed_unique.append(doc)
docs_processed = docs_processed_unique
|
|
|
|
|
|
|
|
# Embedding model used to index the chunks in FAISS (and hence to embed the
# queries passed to similarity_search).
# - device: use the GPU only when one is available. A hard-coded "cuda"
#   fails on CPU-only machines.
# - multi_process=False: sentence-transformers' multi-process encoding starts
#   its worker pool with the "spawn" start method. Spawned workers re-import
#   this script, and since it has no __main__ guard around the pipeline,
#   starting the pool breaks. Single-process encoding avoids that.
# - normalize_embeddings=True: unit-length vectors, so L2 distance and cosine
#   similarity rank neighbours identically.
embedding_model = HuggingFaceEmbeddings(
    model_name=EMBEDDING_MODEL_NAME,
    multi_process=False,
    model_kwargs={"device": "cuda" if torch.cuda.is_available() else "cpu"},
    encode_kwargs={"normalize_embeddings": True},
)
|
|
|
|
|
|
|
|
# Build the FAISS vector store over the deduplicated chunks, embedding each one
# with embedding_model.
# NOTE(review): LangChain's FAISS wrapper appears to back COSINE with an L2
# index. Because the embeddings are normalized, neighbour ranking still matches
# cosine similarity — confirm against the installed langchain_community version.
KNOWLEDGE_VECTOR_DATABASE = FAISS.from_documents(
    documents=docs_processed,
    embedding=embedding_model,
    distance_strategy=DistanceStrategy.COSINE,
)
|
|
|
|
|
|
|
|
# Load the generator LLM and the module-level `tokenizer` used by the chat
# template and text-generation pipeline below.
# On a GPU the weights keep the dtype stored with the checkpoint
# (torch_dtype="auto") instead of being upcast to float32, and the model is
# moved to the GPU. On CPU the default float32 load is kept.
_llm_device = "cuda" if torch.cuda.is_available() else "cpu"
model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    torch_dtype="auto" if _llm_device == "cuda" else None,
).to(_llm_device)
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
|
|
|
|
|
|
|
|
# Chat messages for the RAG prompt. The user turn contains two str.format
# placeholders, {context} and {question}, which are filled in per query after
# the chat template has been rendered.
_RAG_SYSTEM_MESSAGE = (
    "Using the information contained in the context, Give a comprehensive answer "
    "to the question. Respond only to the question asked, response should be "
    "concise and relevant to the question. Provide the number of the source "
    "document when relevant. If the answer cannot be deduced from the context, "
    "do not give an answer."
)
_RAG_USER_MESSAGE = (
    "Context: {context} --- Now here is the Question you need to answer. "
    "Question: {question}"
)

prompt_chat = [
    {"role": "system", "content": _RAG_SYSTEM_MESSAGE},
    {"role": "user", "content": _RAG_USER_MESSAGE},
]

# Render the conversation to a plain string (tokenize=False) and append the
# generation prompt so the model's next tokens form the assistant reply.
RAG_PROMPT_TEMPLATE = tokenizer.apply_chat_template(
    prompt_chat,
    tokenize=False,
    add_generation_prompt=True,
)
|
|
|
|
|
|
|
|
# Generation settings passed to the text-generation pipeline on every query.
# Decoding is greedy (do_sample=False). "temperature" is deliberately absent:
# it only affects sampling, and transformers warns when it is combined with
# do_sample=False.
generation_args = {
    "max_new_tokens": 500,  # upper bound on the answer length, in tokens
    "return_full_text": False,  # return only the generated continuation, not the prompt
    "do_sample": False,
}
|
|
|
|
|
_text_generator = None


def _get_text_generator():
    """Return the shared text-generation pipeline, creating it on first use.

    Building a ``pipeline`` has real setup cost, so it is created once and
    reused by every query instead of being rebuilt per call.
    """
    global _text_generator
    if _text_generator is None:
        _text_generator = pipeline(
            "text-generation",
            model=model,
            tokenizer=tokenizer,
            # Run wherever the model's weights already live, so inputs and
            # weights are always on the same device.
            device=model.device,
        )
    return _text_generator


def _format_context(docs):
    """Join retrieved documents into one context string, numbered from 1.

    The numbering lets the model cite a source document by number, which the
    system prompt asks it to do.
    """
    return "\n".join(
        f"Document {number}:::\n{doc.page_content}"
        for number, doc in enumerate(docs, start=1)
    )


def query_knowledge_base(u_query, k=3):
    """Answer ``u_query`` with retrieval-augmented generation.

    Args:
        u_query: The user's question.
        k: Number of chunks to retrieve from the vector store (default 3).

    Returns:
        The text generated by the LLM for the filled-in RAG prompt, using the
        settings in ``generation_args``. If retrieval returns nothing, the
        context is empty instead of raising ``IndexError``.
    """
    retrieved_docs = KNOWLEDGE_VECTOR_DATABASE.similarity_search(query=u_query, k=k)
    # Give the model every retrieved chunk, not just the top hit.
    context = _format_context(retrieved_docs)
    # str.format inserts the context verbatim; braces inside it are not re-parsed.
    final_prompt = RAG_PROMPT_TEMPLATE.format(question=u_query, context=context)
    output = _get_text_generator()(final_prompt, **generation_args)
    return output[0]["generated_text"]
|
|
|
|
|
if __name__ == "__main__":
    # Demo: ask one question against the knowledge base and print the answer.
    u_query = "give the match summary of royal challengers bengaluru and mumbai indians in 2024"
    print("YOUR QUESTION:\n", u_query, "\n")

    answer = query_knowledge_base(u_query)
    print("MICROSOFT 128K ANSWER: \n", answer)
|
|
|