# git_ipl / app.py
# Author: ram36 — commit "Create app.py" (858ea61, verified)
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
from langchain.docstore.document import Document as LangchainDocument
from langchain.vectorstores import FAISS
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.vectorstores.utils import DistanceStrategy
import pandas as pd
from tqdm import tqdm
from langchain.text_splitter import RecursiveCharacterTextSplitter
import matplotlib.pyplot as plt
# Hugging Face model ids: the sentence-embedding model used for retrieval and
# the Phi-3 mini instruct model used to generate answers.
EMBEDDING_MODEL_NAME = "thenlper/gte-small"
MODEL_NAME = "microsoft/Phi-3-mini-128k-instruct"

# Never truncate cell contents when pandas objects are printed.
pd.options.display.max_colwidth = None
# Source text files that make up the knowledge base, in concatenation order.
KNOWLEDGE_FILES = (
    "iplteams_info.txt",
    "match_summaries_sentences.txt",
    "formatted_playersinfo.txt",
)


def _read_text(path):
    """Return the full contents of the text file at *path*, decoded as UTF-8."""
    # Explicit encoding so behaviour does not depend on the platform default.
    with open(path, "r", encoding="utf-8") as fp:
        return fp.read()


# Combine the contents of all three files, separated by three newlines — the
# same delimiter used to split the combined text into sections below.
combined_content = "\n\n\n".join(_read_text(path) for path in KNOWLEDGE_FILES)
# Break the combined text into sections wherever three consecutive newlines
# occur, then wrap every section in a LangchainDocument (tqdm shows progress).
s = combined_content.split("\n\n\n")
RAW_KNOWLEDGE_BASE = []
for section in tqdm(s):
    RAW_KNOWLEDGE_BASE.append(LangchainDocument(page_content=section))
# Split every section into overlapping chunks of at most 1000 characters,
# preferring markdown-style boundaries (headings, code fences, horizontal
# rules) and falling back to paragraphs, lines, words and finally characters.
MARKDOWN_SEPARATORS = [
    "\n#{1,6}",
    "```\n",
    "\n\\*\\*\\*+\n",
    "\n---+\n",
    "\n__+\n",
    "\n\n",
    "\n",
    " ",
    "",
]
text_splitter = RecursiveCharacterTextSplitter(
    chunk_size=1000,
    chunk_overlap=100,
    add_start_index=True,
    strip_whitespace=True,
    separators=MARKDOWN_SEPARATORS,
    # The separators above are regular expressions ("#{1,6}", "---+", ...).
    # Without this flag the splitter escapes them and matches them literally,
    # so the markdown separators would never fire.
    is_separator_regex=True,
)
# split_documents processes each document independently (start_index is
# per document), so one call over the whole list equals splitting one by one.
docs_processed = text_splitter.split_documents(RAW_KNOWLEDGE_BASE)
# Optional sanity check: histogram of chunk lengths, measured in tokens of the
# embedding model (the splitter above sizes chunks in characters, not tokens).
# A dedicated name keeps this tokenizer distinct from the LLM tokenizer
# loaded later.
embedding_tokenizer = AutoTokenizer.from_pretrained(EMBEDDING_MODEL_NAME)
lengths = [
    len(embedding_tokenizer.encode(doc.page_content))
    for doc in tqdm(docs_processed)
]
ax = pd.Series(lengths).hist()
# Title the Axes once; a later plt.title() would target this same Axes and
# silently replace the title.
ax.set_title("Histogram of Document Lengths")
ax.set_xlabel("Tokens per chunk")
ax.set_ylabel("Number of chunks")
plt.show()
# Remove chunks whose text duplicates an earlier chunk, keeping the first
# occurrence of each text and the original order.
unique_texts = set()  # texts already kept; a set gives O(1) membership tests
docs_processed_unique = []
for doc in docs_processed:
    if doc.page_content not in unique_texts:
        unique_texts.add(doc.page_content)
        docs_processed_unique.append(doc)
docs_processed = docs_processed_unique
# Load the sentence-embedding model used to embed the chunks for FAISS.
# Use the GPU when one is available and fall back to the CPU instead of
# failing on CPU-only hosts.
EMBEDDING_DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
embedding_model = HuggingFaceEmbeddings(
    model_name=EMBEDDING_MODEL_NAME,
    # multi_process=True makes sentence-transformers start a worker pool with
    # the "spawn" start method. Each worker re-imports this module and re-runs
    # all of its unguarded top-level code (file loading, plotting, indexing),
    # which aborts or hangs pool start-up. Encode in-process instead.
    # NOTE(review): some langchain_community versions also skip encode_kwargs
    # (normalize_embeddings) on the multi-process path — another reason to
    # keep this False.
    multi_process=False,
    model_kwargs={"device": EMBEDDING_DEVICE},
    # Unit-length vectors, so inner-product / L2 ranking matches cosine ranking.
    encode_kwargs={"normalize_embeddings": True},
)
# Embed every deduplicated chunk with the embedding model and index the
# vectors in a FAISS store for similarity search.
# NOTE(review): LangChain's FAISS wrapper appears to build an L2 index for
# DistanceStrategy.COSINE; because the embeddings are normalized, the L2
# ranking still matches cosine ranking — confirm for the installed version.
KNOWLEDGE_VECTOR_DATABASE = FAISS.from_documents(
    documents=docs_processed,
    embedding=embedding_model,
    distance_strategy=DistanceStrategy.COSINE,
)
# Load the generator LLM and its tokenizer from the same checkpoint; from here
# on, `tokenizer` refers to the Phi-3 tokenizer.
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForCausalLM.from_pretrained(MODEL_NAME)
# Chat-formatted RAG prompt: a system instruction and a user turn containing
# {context} and {question} placeholders. query_knowledge_base fills them with
# str.format after the chat template has been rendered.
prompt_chat = [
    dict(
        role="system",
        content=(
            "Using the information contained in the context, Give a comprehensive "
            "answer to the question. Respond only to the question asked, response "
            "should be concise and relevant to the question. Provide the number of "
            "the source document when relevant. If the answer cannot be deduced "
            "from the context, do not give an answer."
        ),
    ),
    dict(
        role="user",
        content=(
            "Context: {context} --- Now here is the Question you need to answer. "
            "Question: {question}"
        ),
    ),
]
# Render to a plain string (no token ids) that ends with the generation
# prompt, so the model continues as the assistant.
RAG_PROMPT_TEMPLATE = tokenizer.apply_chat_template(
    prompt_chat,
    tokenize=False,
    add_generation_prompt=True,
)
# Generation settings passed to the text-generation pipeline.
# do_sample=False means greedy (deterministic) decoding, so no temperature is
# set: transformers ignores sampling parameters in greedy mode and only warns
# that they were supplied.
generation_args = {
    "max_new_tokens": 500,
    "return_full_text": False,  # return only the newly generated answer text
    "do_sample": False,
}
# Lazily-built text-generation pipeline shared by all queries.
_TEXT_GENERATION_PIPE = None


def _get_text_generation_pipe():
    """Return the shared text-generation pipeline, building it on first use."""
    global _TEXT_GENERATION_PIPE
    if _TEXT_GENERATION_PIPE is None:
        # Wraps the already-loaded model/tokenizer; do it once, not per query.
        _TEXT_GENERATION_PIPE = pipeline("text-generation", model=model, tokenizer=tokenizer)
    return _TEXT_GENERATION_PIPE


def query_knowledge_base(u_query, k=3):
    """Answer ``u_query`` with the LLM, grounded in the top-``k`` retrieved chunks.

    Args:
        u_query: The user's natural-language question.
        k: Number of chunks to retrieve from the FAISS store and include
            in the prompt context. Defaults to 3.

    Returns:
        The ``generated_text`` of the first pipeline output.
    """
    retrieved_docs = KNOWLEDGE_VECTOR_DATABASE.similarity_search(query=u_query, k=k)
    # Use every retrieved chunk, numbered so the model can cite "the number of
    # the source document" as the system prompt asks. An empty result yields
    # an empty document list instead of an IndexError.
    context = "\nExtracted documents:\n" + "".join(
        f"Document {i}:::\n{doc.page_content}\n"
        for i, doc in enumerate(retrieved_docs)
    )
    final_prompt = RAG_PROMPT_TEMPLATE.format(question=u_query, context=context)
    output = _get_text_generation_pipe()(final_prompt, **generation_args)
    return output[0]["generated_text"]
if __name__ == "__main__":
    # Demo: run one sample question through retrieval and generation.
    u_query = "give the match summary of royal challengers bengaluru and mumbai indians in 2024"
    print("YOUR QUESTION:\n", u_query, "\n")
    answer = query_knowledge_base(u_query)
    print("MICROSOFT 128K ANSWER: \n", answer)