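# Basic RAG with the Hugging Face Inference API: a Gradio chatbot that
# answers questions from uploaded PDF files via LlamaIndex.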
import os
import time
from llama_index.core import SimpleDirectoryReader
from llama_index.embeddings.huggingface import HuggingFaceEmbedding
from llama_index.core import VectorStoreIndex
from custom_llm import CustomLLM
import gradio as gr
# import shutil
import tempfile
repo_id = "mistralai/Mistral-7B-Instruct-v0.1"
model_type = 'text-generation'
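# The Inference API token is read from the HF_INFER_API environment variable.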
API_TOKEN = os.getenv('HF_INFER_API')
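# Uploaded PDFs are staged in a temporary directory for the session.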
temp_dir = tempfile.TemporaryDirectory()
embedding_model = HuggingFaceEmbedding(model_name="BAAI/bge-small-en-v1.5")
llm = CustomLLM(repo_id=repo_id, model_type=model_type, api_token=API_TOKEN)
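# Append the user's message to the history and disable the textbox until
# the bot has answered.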
def add_text(history, text):
    history = history + [(text, None)]
    return history, gr.Textbox(value="", interactive=False)
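# Count the PDF files uploaded so far by scanning the chat history for
# entries that carry a .pdf filename.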
def hasFile(history):
    pdf_files = 0
    for user_prompt, bot_response in history:
        # Entries marked with "!!!" are rejected uploads; don't count them
        if '.pdf' in user_prompt.lower() and '!!!' not in user_prompt:
            pdf_files += 1
    return pdf_files
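# Log a model switch in the chat history; the "===>" marker is picked up
# by bot() to emit a confirmation message.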
def modelChanged(history, drop):
    history = history + [(f'===> {drop}', None)]
    return history, drop
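# Build the RAG pipeline over the PDFs in the temp directory. Note that
# this re-reads and re-embeds the files on every call.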
def getEngine(llm):
    loader = SimpleDirectoryReader(
        input_dir=temp_dir.name,
        recursive=True,
        required_exts=[".pdf", ".PDF"],
    )
    # Load the files as documents
    documents = loader.load_data()
    # Create an in-memory vector index over the documents
    index = VectorStoreIndex.from_documents(
        documents,
        embed_model=embedding_model,
    )
    # Create a query engine backed by the given LLM
    query_engine = index.as_query_engine(llm=llm)
    return query_engine
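# Copy a PDF byte for byte; used in place of shutil.copyfile (see the
# commented-out call in add_file).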
def copy_pdf(source_path, destination_path):
    # Read the entire source PDF in binary mode
    with open(source_path, "rb") as source_file:
        data = source_file.read()
    # Write the copied data to the destination file
    with open(destination_path, "wb") as destination_file:
        destination_file.write(data)
    print(f"PDF copied successfully from {source_path} to {destination_path}")
def add_file(history, file):
    pdf_files = hasFile(history)
    if pdf_files + 1 >= 4:
        # Mark the rejected upload with "!!!" so bot() reports the limit
        history = history + [("%s!!!" % os.path.basename(file.name), None)]
        return history
    file_path = os.path.join(temp_dir.name, os.path.basename(file.name))
    # shutil.copyfile(file.name, file_path)  # <--- Asynchronous
    copy_pdf(file.name, file_path)
    history = history + [(os.path.basename(file.name), None)]
    return history
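# Remove all uploaded PDFs when the user clears the chat.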
def clearClick():
    global temp_dir
    print("clear temp files...")
    temp_dir.cleanup()
    # Recreate the temp directory so later uploads have a valid target
    temp_dir = tempfile.TemporaryDirectory()
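# Build a model-specific chat prompt from the history: Mistral's [INST]
# format or Gemma's turn format, selected by the repo name.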
def format_prompt(message, history, model):
    if model is None or 'mistral' in model.lower():
        prompt = "<s>"
        for user_prompt, bot_response in history:
            prompt += f"[INST] {user_prompt} [/INST]"
            prompt += f" {bot_response}</s> "
        prompt += f"[INST] {message} [/INST]"
    elif 'google' in model.lower():
        prompt = "<bos>"
        for user_prompt, bot_response in history:
            prompt += f"<start_of_turn>user {user_prompt} <end_of_turn><start_of_turn>model {bot_response}<end_of_turn>"
        prompt += f"<start_of_turn>user {message} <end_of_turn><start_of_turn>model"
    else:
        # Fallback: pass the message through without special formatting
        prompt = message
    return prompt
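# Produce the bot's reply to the latest history entry. Control entries
# (model switch, upload, rejected upload) get canned responses; real
# questions go through RAG when PDFs are present, otherwise straight to
# the model. The reply is streamed back character by character.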
def bot(history, model=None):
    print("===> model: ", model)
    local_llm = llm
    if model:
        local_llm = CustomLLM(repo_id=model, model_type=model_type, api_token=API_TOKEN)
    if history and '===>' in history[-1][0]:
        new_model = history[-1][0].replace("===>", "")
        response = f"You have changed the model to {new_model}"
    elif history and '.pdf!!!' in history[-1][0].lower():
        response = "Unable to add file. Maximum 3 files allowed."
    elif history and '.pdf' in history[-1][0].lower():
        response = "You uploaded a PDF file. You can ask questions from the file."
    else:
        prompt = history[-1][0]
        if hasFile(history):
            # Answer from the uploaded PDFs via the RAG query engine
            query_engine = getEngine(local_llm)
            response = query_engine.query(prompt)
            print("Response from file")
        else:
            # No files uploaded: answer directly from the model
            response = local_llm.predict(format_prompt(prompt, history, model))
            print("Response from Model")
    # print(response)
    # response = "Thats cool!"
    history[-1][1] = ""
    for character in str(response):
        history[-1][1] += character
        # time.sleep(0.05)
        yield history
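# Gradio UI: chatbot pane, model dropdown, question textbox, PDF upload
# button, and a clear button.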
with gr.Blocks() as demo:
    gr.Markdown(
        """
        <div style="display: grid; justify-content: center;">
        <h1>Basic RAG with Huggingface Inference API</h1>
        <h4>For best performance, start with small PDF files (fewer than 20 pages).</h4>
        </div>
        """
    )
    chatbot = gr.Chatbot(
        [],
        elem_id="chatbot",
        bubble_full_width=False,
        # avatar_images=(None, os.path.join(os.path.dirname(__file__), "avatar.png")),
    )
    with gr.Row():
        drop = gr.Dropdown(
            [
                ("Mixtral-8x7B-Instruct-v0.1", "mistralai/Mixtral-8x7B-Instruct-v0.1"),
                ("Mistral-7B-Instruct-v0.2", "mistralai/Mistral-7B-Instruct-v0.2"),
                ("gemma-7b-it", "google/gemma-7b-it"),
                ("gemma-2b-it", "google/gemma-2b-it"),
            ],
            value="mistralai/Mixtral-8x7B-Instruct-v0.1",
            label="Model",
        )
    with gr.Row():
        txt = gr.Textbox(
            scale=4,
            show_label=False,
            placeholder="Type your question and press enter",
            container=False,
        )
        btn = gr.UploadButton("📁", file_types=[".pdf"])
        clear_btn = gr.ClearButton([chatbot, txt])

    # Model switch: log the change in the chat, then let bot() confirm it
    drop.change(modelChanged, [chatbot, drop], [chatbot, drop], queue=False).then(
        bot, [chatbot, drop], chatbot
    )
    # Question submit: add the text, run the bot, then re-enable the textbox
    txt_msg = txt.submit(add_text, [chatbot, txt], [chatbot, txt], queue=False).then(
        bot, [chatbot, drop], chatbot, api_name="bot_response"
    )
    txt_msg.then(lambda: gr.Textbox(interactive=True), None, [txt], queue=False)
    # File upload: copy the PDF, then let bot() acknowledge it
    file_msg = btn.upload(add_file, [chatbot, btn], [chatbot], queue=False).then(
        bot, [chatbot, drop], chatbot
    )
    # The clear button also removes the uploaded PDFs
    clear_btn.click(clearClick)

demo.queue()
demo.launch()