|
|
import gradio as gr |
|
|
import pandas as pd |
|
|
from tqdm import tqdm |
|
|
from langchain.docstore.document import Document as LangchainDocument |
|
|
from langchain.text_splitter import RecursiveCharacterTextSplitter |
|
|
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline |
|
|
from langchain_community.vectorstores import FAISS |
|
|
from langchain_community.embeddings import HuggingFaceEmbeddings |
|
|
from langchain_community.vectorstores.utils import DistanceStrategy |
|
|
import torch |
|
|
import matplotlib.pyplot as plt |
|
|
from typing import Optional, List |
|
|
from tqdm import tqdm |
|
|
from langchain_community.vectorstores import FAISS |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
pd.set_option("display.max_colwidth", None)

# Knowledge-base source files: team info, match summaries, player info.
_KB_FILES = [
    "iplteams_info.txt",
    "match_summaries_sentences.txt",
    "formatted_playersinfo.txt",
]

# Read each file with an explicit encoding — the original relied on the
# platform default, which breaks on Windows for non-ASCII player names.
_contents = []
for _path in _KB_FILES:
    with open(_path, "r", encoding="utf-8") as fp:
        _contents.append(fp.read())

# Join with the same triple-newline delimiter used later to re-split
# the corpus into individual documents.
combined_content = "\n\n\n".join(_contents)
|
|
|
|
|
|
|
|
# Re-split the combined corpus into individual documents on the
# triple-newline delimiter inserted between (and inside) the files.
s = combined_content.split("\n\n\n")

# Quick sanity check: show the first document and the document count.
print(s[0])
print(len(s))

# Wrap each raw text snippet in a LangChain Document for the splitter.
RAW_KNOWLEDGE_BASE = [
    LangchainDocument(page_content=snippet) for snippet in tqdm(s)
]
|
|
|
|
|
|
|
|
# Separator hierarchy for RecursiveCharacterTextSplitter, ordered from the
# strongest markdown structure (headings, code fences, horizontal rules)
# down to paragraph breaks, lines, words, and finally single characters.
MARKDOWN_SEPARATORS = [
    "\n#{1,6}",
    "```\n",
    "\n\\*\\*\\*+\n",
    "\n---+\n",
    "\n__+\n",
    "\n\n",
    "\n",
    " ",
    ""
]

# Character-count-based splitter (~1000 chars per chunk, 100-char overlap);
# used below only for the token-length histogram pass. The final index is
# built from the token-based split_documents() helper instead.
text_splitter = RecursiveCharacterTextSplitter(
    chunk_size=1000,
    chunk_overlap=100,
    add_start_index=True,      # record each chunk's offset in its source doc
    strip_whitespace=True,
    separators=MARKDOWN_SEPARATORS,
)
|
|
|
|
|
# Chunk every raw document with the character-based splitter above.
docs_processed = []
for doc in RAW_KNOWLEDGE_BASE:
    docs_processed += text_splitter.split_documents([doc])

# Measure each chunk in *tokens* of the embedding model's tokenizer, to
# check whether character-based chunks fit the embedding context window.
tokenizer = AutoTokenizer.from_pretrained("thenlper/gte-small")
lengths = [len(tokenizer.encode(doc.page_content)) for doc in tqdm(docs_processed)]

# Plot the token-length distribution. The original set two conflicting
# titles ("Histogram of Document Lengths" via set_title, immediately
# overwritten by plt.title("Distribution")); keep the one that took effect.
# Note: Series.hist() returns an Axes, not a Figure.
ax = pd.Series(lengths).hist()
ax.set_title("Distribution")
plt.show()
|
|
|
|
|
# Sentence-embedding model used both for token-length measurement and
# for building the FAISS vector index below.
EMBEDDING_MODEL_NAME = "thenlper/gte-small"
|
|
|
|
|
def split_documents(
    chunk_size: int,
    knowledge_base: List[LangchainDocument],
    tokenizer_name: Optional[str] = EMBEDDING_MODEL_NAME,
) -> List[LangchainDocument]:
    """Split documents into token-bounded chunks and drop exact duplicates.

    Args:
        chunk_size: Maximum chunk length, measured in tokens of the given
            tokenizer (not characters).
        knowledge_base: Documents to split.
        tokenizer_name: HuggingFace model id whose tokenizer defines the
            token count; defaults to the embedding model so chunks fit its
            context window.

    Returns:
        De-duplicated list of chunk documents, preserving first-seen order.
    """
    splitter = RecursiveCharacterTextSplitter.from_huggingface_tokenizer(
        AutoTokenizer.from_pretrained(tokenizer_name),
        chunk_size=chunk_size,
        chunk_overlap=chunk_size // 10,  # 10% overlap keeps context across chunk edges
        add_start_index=True,
        strip_whitespace=True,
        separators=MARKDOWN_SEPARATORS,
    )

    chunks: List[LangchainDocument] = []
    for doc in knowledge_base:
        chunks.extend(splitter.split_documents([doc]))

    # De-duplicate by page content with a real set (the original abused a
    # dict with dummy True values); order of first occurrence is kept.
    seen: set = set()
    unique_chunks: List[LangchainDocument] = []
    for chunk in chunks:
        if chunk.page_content not in seen:
            seen.add(chunk.page_content)
            unique_chunks.append(chunk)
    return unique_chunks
|
|
|
|
|
# Rebuild the chunk list with the token-based splitter (512 tokens max),
# replacing the character-based pass used for the histogram above.
docs_processed = split_documents(
    512,
    RAW_KNOWLEDGE_BASE,
    tokenizer_name=EMBEDDING_MODEL_NAME,
)

# Inspect the result and confirm whether a GPU is available.
print(len(docs_processed))
print(docs_processed[:3])

print(torch.cuda.is_available())
|
|
|
|
|
# Pick the embedding device at runtime — the original hard-coded "cuda",
# which crashes outright on CPU-only machines.
_embed_device = "cuda" if torch.cuda.is_available() else "cpu"

embedding_model = HuggingFaceEmbeddings(
    model_name=EMBEDDING_MODEL_NAME,
    multi_process=True,
    model_kwargs={"device": _embed_device},
    # Unit-norm embeddings so cosine distance behaves as expected below.
    encode_kwargs={"normalize_embeddings": True},
)

# Build the FAISS index over all chunks using cosine similarity.
KNOWLEDGE_VECTOR_DATABASE = FAISS.from_documents(
    docs_processed,
    embedding_model,
    distance_strategy=DistanceStrategy.COSINE,
)
|
|
|
|
|
# Fix the RNG seed for reproducible generation.
torch.random.manual_seed(0)

# Same device fallback as for embeddings: the original hard-coded
# device_map="cuda" and failed on machines without a GPU.
_llm_device = "cuda" if torch.cuda.is_available() else "cpu"

model = AutoModelForCausalLM.from_pretrained(
    "microsoft/Phi-3-mini-128k-instruct",
    device_map=_llm_device,
    torch_dtype="auto",          # use the checkpoint's native dtype
    trust_remote_code=True,      # Phi-3 ships custom modeling code
)
tokenizer = AutoTokenizer.from_pretrained("microsoft/Phi-3-mini-128k-instruct")

# Text-generation pipeline wrapping the Phi-3 model + tokenizer.
pipe = pipeline(
    "text-generation",
    model=model,
    tokenizer=tokenizer,
)
|
|
|
|
|
# Decoding settings: greedy decoding (do_sample=False, so temperature is
# inert), cap of 500 new tokens, and return only the generated
# continuation instead of echoing the prompt.
generation_args = dict(
    max_new_tokens=500,
    return_full_text=False,
    temperature=0.0,
    do_sample=False,
)
|
|
|
|
|
# Chat messages for the RAG prompt. {context} and {question} are kept as
# literal str.format placeholders, filled in once per query.
_SYSTEM_PROMPT = (
    "Using the information contained in the context,\n"
    "Give a comprehensive answer to the question.\n"
    "Respond only to the question asked , response should be concise and relevant to the question.\n"
    "provide the number of the source document when relevant.\n"
    "If the answer cannot be deduced from the context, do not give an answer"
)

_USER_PROMPT = (
    "Context:\n"
    "{context}\n"
    "---\n"
    "Now here is the Question you need to answer.\n"
    "Question:{question}\n"
)

prompt_chat = [
    {"role": "system", "content": _SYSTEM_PROMPT},
    {"role": "user", "content": _USER_PROMPT},
]
|
|
|
|
|
# Render the chat messages into one prompt string with the Phi-3
# tokenizer's chat template. tokenize=False keeps it as text so the
# literal {context}/{question} placeholders survive for str.format;
# add_generation_prompt=True appends the assistant-turn marker.
RAG_PROMPT_TEMPLATE = tokenizer.apply_chat_template(
    prompt_chat, tokenize=False, add_generation_prompt=True,
)
print(RAG_PROMPT_TEMPLATE)
|
|
|
|
|
# One-off smoke test of the full RAG path before wiring up the UI.
u_query = "give the match summary of royal challengers bengaluru and mumbai indians in 2024"
retrieved_docs = KNOWLEDGE_VECTOR_DATABASE.similarity_search(query=u_query, k=3)

# Use ALL retrieved documents, numbered so the model can cite sources.
# The original asked for k=3 but built the context from retrieved_docs[0]
# only — discarding two results and leaving nothing for the prompt's
# "provide the number of the source document" instruction to refer to.
context = "\n\n".join(
    f"Document {i}:\n{doc.page_content}" for i, doc in enumerate(retrieved_docs)
)
final_prompt = RAG_PROMPT_TEMPLATE.format(question=u_query, context=context)

output = pipe(final_prompt, **generation_args)
print("YOUR QUESTION:\n", u_query, "\n")
print("MICROSOFT 128K ANSWER: \n", output[0]['generated_text'])
|
|
|
|
|
def handle_query(question):
    """Answer *question* with retrieval-augmented generation.

    Retrieves the top-3 chunks from the FAISS index, formats them into a
    numbered context, and generates an answer with the Phi-3 pipeline.

    Args:
        question: Free-text user question from the Gradio textbox.

    Returns:
        The model's generated answer as a string.
    """
    retrieved_docs = KNOWLEDGE_VECTOR_DATABASE.similarity_search(query=question, k=3)
    # Include every retrieved document, numbered so the model can cite a
    # source (the original used retrieved_docs[0] only, wasting the other
    # two results the k=3 search paid for).
    context = "\n\n".join(
        f"Document {i}:\n{doc.page_content}" for i, doc in enumerate(retrieved_docs)
    )
    final_prompt = RAG_PROMPT_TEMPLATE.format(question=question, context=context)
    output = pipe(final_prompt, **generation_args)
    return output[0]['generated_text']
|
|
|
|
|
# Minimal Gradio UI: one text input (the question), one text output (the answer).
interface = gr.Interface(
    fn=handle_query,
    inputs="text",
    outputs="text",
    title="IPL Match Summary Generator",
    description="Get the match summary of IPL teams based on your query.",
)

# BUG FIX: the launch keyword is `share`, not `sharing` — the original
# raised "TypeError: launch() got an unexpected keyword argument 'sharing'"
# and the app never started. share=True also creates a public gradio.live link.
interface.launch(share=True)
|
|
|