# webui / app.py
# GGINCoder's picture
# Update app.py
# ef4ae51 verified
import os
import gradio as gr
from langchain.document_loaders import PyMuPDFLoader, TextLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.vectorstores import Chroma
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.chains import RetrievalQA
from langchain.prompts import PromptTemplate
from langchain_groq import ChatGroq
import chardet
import pandas as pd
import plotly.graph_objs as go
# Directory containing this script; when ``__file__`` is not defined the
# current working directory is used instead.
try:
    SCRIPT_DIR = os.path.split(os.path.abspath(__file__))[0]
except NameError:
    SCRIPT_DIR = os.getcwd()

# Location of the CSV file that stores the evaluation records.
CSV_PATH = os.path.join(SCRIPT_DIR, "evaluations.csv")
def detect_encoding(file_path):
    """Guess the character encoding of the file at ``file_path``.

    The complete file is read as raw bytes and passed to ``chardet.detect``;
    the ``'encoding'`` entry of the detection result is returned.
    """
    with open(file_path, 'rb') as stream:
        detection = chardet.detect(stream.read())
    return detection['encoding']
def setup(file_path):
    """Index a document and return what a retrieval QA chain needs.

    The file is loaded (PDF via ``PyMuPDFLoader``, TXT via ``TextLoader`` with
    a detected encoding), split with ``chunk_size=1000`` and
    ``chunk_overlap=200``, embedded with
    ``sentence-transformers/all-MiniLM-L6-v2`` on the CPU and stored in a
    Chroma collection persisted under ``db``.

    Args:
        file_path: Path to a ``.pdf`` or ``.txt`` file (extension is matched
            case-insensitively).

    Returns:
        A ``(retriever, chain_type_kwargs)`` tuple, where
        ``chain_type_kwargs`` is ``{"prompt": <PromptTemplate>}``.

    Raises:
        ValueError: If the file extension is neither ``.pdf`` nor ``.txt``.
    """
    import hashlib  # only needed for the document fingerprint below

    extension = os.path.splitext(file_path)[1].lower()
    if extension == '.pdf':
        loader = PyMuPDFLoader(file_path)
    elif extension == '.txt':
        loader = TextLoader(file_path, encoding=detect_encoding(file_path))
    else:
        raise ValueError("Unsupported file type. Please upload a PDF or TXT file.")

    document_data = loader.load()
    text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200)
    all_splits = text_splitter.split_documents(document_data)

    # Chroma appends to an existing persisted collection. Without a
    # per-document collection and stable ids, every call (this function runs
    # once per model for every question) would store the same chunks again
    # under fresh ids, and chunks of previously uploaded files would be
    # retrieved alongside the current document.
    with open(file_path, 'rb') as source:
        digest = hashlib.sha256(source.read()).hexdigest()
    collection_name = f"doc-{digest[:32]}"
    chunk_ids = [f"{digest[:16]}-{index}" for index in range(len(all_splits))]

    persist_directory = 'db'
    model_name = "sentence-transformers/all-MiniLM-L6-v2"
    model_kwargs = {'device': 'cpu'}
    embedding = HuggingFaceEmbeddings(model_name=model_name, model_kwargs=model_kwargs)
    vectordb = Chroma.from_documents(
        documents=all_splits,
        embedding=embedding,
        ids=chunk_ids,
        collection_name=collection_name,
        persist_directory=persist_directory,
    )

    sys_prompt = """
You can only answer questions that are relevant to the content of the document.
If the question that user asks is not relevant to the content of the document, you should say "I don't know".
Do not provide any information that is not in the document.
Do not provide any information by yourself.
You should use sentences from the document to answer the questions.
"""
    # ``{context}`` and ``{question}`` stay literal here: the f-string below
    # only interpolates ``sys_prompt`` and ``instruction`` themselves, leaving
    # the two placeholders for PromptTemplate to fill in.
    instruction = """CONTEXT:\n\n {context}\n\nQuestion: {question}"""
    prompt_template = f"[INST] <<SYS>>\n{sys_prompt}\n<</SYS>>\n\n{instruction}[/INST]"
    llama_prompt = PromptTemplate(
        template=prompt_template, input_variables=["context", "question"]
    )
    chain_type_kwargs = {"prompt": llama_prompt}
    retriever = vectordb.as_retriever()
    return retriever, chain_type_kwargs
def setup_qa(file_path, model_name, api_key):
    """Create a ``RetrievalQA`` chain over ``file_path`` for one Groq model.

    The retriever and prompt come from ``setup(file_path)``; the language
    model is a ``ChatGroq`` instance configured with ``model_name`` and
    ``api_key``. The chain uses the "stuff" chain type with verbose output.
    """
    doc_retriever, prompt_kwargs = setup(file_path)
    groq_llm = ChatGroq(groq_api_key=api_key, model_name=model_name)
    return RetrievalQA.from_chain_type(
        llm=groq_llm,
        chain_type="stuff",
        retriever=doc_retriever,
        chain_type_kwargs=prompt_kwargs,
        verbose=True,
    )
def chat_with_models(file, model_a, model_b, api_key, history_a, history_b, question):
    """Ask both models the same question about the uploaded document.

    Args:
        file: The uploaded file, either an object exposing the path as
            ``.name`` or the path string itself; ``None`` when nothing has
            been uploaded.
        model_a: Model name used for the first chat.
        model_b: Model name used for the second chat.
        api_key: Groq API key.
        history_a: Existing chat history for model A (``None`` is treated as
            empty).
        history_b: Existing chat history for model B (``None`` is treated as
            empty).
        question: The user's question.

    Returns:
        ``(history_a, history_b, "")``: the two updated histories and an
        empty string for the question box. Validation problems and
        exceptions are reported as a ``(message, None)`` entry appended to
        both histories. A blank question leaves both histories unchanged.
    """
    # Work on copies so the caller's lists are never mutated, and tolerate a
    # missing history.
    history_a = list(history_a or [])
    history_b = list(history_b or [])

    def reply_to_both(message):
        # Append the same notice to both chats and clear the question box.
        return history_a + [(message, None)], history_b + [(message, None)], ""

    if file is None:
        return reply_to_both("請上傳文件。")
    if not api_key:
        return reply_to_both("請輸入 API 金鑰。")

    # Accept both file-like upload objects (path in ``.name``) and plain path
    # strings.
    file_path = getattr(file, "name", file)
    extension = os.path.splitext(file_path)[1].lower()
    if extension not in ('.pdf', '.txt'):
        return reply_to_both("只能上傳PDF或TXT檔案,不接受其他的格式。")

    # Nothing to ask: skip indexing and the two model calls.
    if not question or not question.strip():
        return history_a, history_b, ""

    try:
        qa_a = setup_qa(file_path, model_a, api_key)
        qa_b = setup_qa(file_path, model_b, api_key)
        response_a = qa_a.invoke(question)
        response_b = qa_b.invoke(question)
    except Exception as e:
        # UI boundary: surface any failure in the chat instead of crashing.
        return reply_to_both(f"遇到錯誤:{str(e)}")

    return (
        history_a + [(question, response_a["result"])],
        history_b + [(question, response_b["result"])],
        "",
    )
def load_or_create_df():
    """Return the evaluation records as a DataFrame.

    Reads ``CSV_PATH`` when that file exists; otherwise returns an empty
    frame with the columns 'Model A', 'Model B', 'Evaluation' and 'Count'.
    """
    if not os.path.exists(CSV_PATH):
        return pd.DataFrame(columns=['Model A', 'Model B', 'Evaluation', 'Count'])
    return pd.read_csv(CSV_PATH)
def record_evaluation(df, model_a, model_b, evaluation):
    """Append one evaluation to ``df`` and write the result to ``CSV_PATH``.

    The new row carries the two model names, the evaluation label and a
    'Count' of 1. The input frame is not modified; the combined frame is
    written without the index and returned.
    """
    record = {
        'Model A': [model_a],
        'Model B': [model_b],
        'Evaluation': [evaluation],
        'Count': [1],
    }
    combined = pd.concat([df, pd.DataFrame(record)], ignore_index=True)
    combined.to_csv(CSV_PATH, index=False)
    return combined
def update_statistics(df):
    """Build a grouped bar chart of evaluation counts.

    'Count' is summed per ('Model A', 'Model B', 'Evaluation') combination;
    each combination becomes one bar trace named after the evaluation, placed
    at the x position "<Model A> vs <Model B>".
    """
    totals = (
        df.groupby(['Model A', 'Model B', 'Evaluation'])['Count']
        .sum()
        .reset_index()
    )
    bars = []
    for _, record in totals.iterrows():
        pair_label = f"{record['Model A']} vs {record['Model B']}"
        bars.append(
            go.Bar(name=record['Evaluation'], x=[pair_label], y=[record['Count']])
        )
    chart = go.Figure(data=bars)
    chart.update_layout(barmode='group', title='Model Comparison Statistics')
    return chart
def evaluate(df, model_a, model_b, evaluation):
    """Record an evaluation and rebuild the statistics chart.

    Returns the updated records twice (the same frame is used for two
    outputs) followed by the figure from ``update_statistics``.
    """
    records = record_evaluation(df, model_a, model_b, evaluation)
    chart = update_statistics(records)
    return records, records, chart
# Model names offered in the two model dropdowns.
models = [
    "llama3-70b-8192",
    "mixtral-8x7b-32768",
    "llama3-8b-8192",
    "gemma-7b-it",
]
def create_demo():
    """Build and return the Gradio Blocks UI for comparing two models.

    The UI has two model dropdowns, an API key field, a file upload, one
    chatbot per model, a question box with a Send button, four evaluation
    buttons, a table of evaluation records and a statistics plot.

    Returns:
        The ``gr.Blocks`` instance (not launched).
    """
    # NOTE(review): the row grouping below was reconstructed from a source
    # without indentation; confirm which components belong in each gr.Row.
    with gr.Blocks(theme=gr.themes.Soft()) as demo:
        # Two models are compared (Model A and Model B), not three.
        gr.Markdown("# Choose two models to compare(With KnowledgeGraphRAG)")
        with gr.Row():
            model_a = gr.Dropdown(choices=models, label="Model A", value=models[0])
            model_b = gr.Dropdown(choices=models, label="Model B", value=models[-1])
        api_key = gr.Textbox(label="Enter your Groq API Key", type="password")
        file_input = gr.File(label="Upload PDF or TXT file")
        with gr.Row():
            chat_a = gr.Chatbot(label="Model A")
            chat_b = gr.Chatbot(label="Model B")
        question = gr.Textbox(label="Enter your prompt and press ENTER")
        send_btn = gr.Button("Send")
        with gr.Row():
            a_better = gr.Button("A is better")
            b_better = gr.Button("B is better")
            tie = gr.Button("Tie")
            both_bad = gr.Button("Both are bad")
        evaluation_df = gr.State(load_or_create_df())
        evaluation_table = gr.Dataframe(label="Evaluation Records")
        statistics_plot = gr.Plot(label="Evaluation Statistics")

        chat_inputs = [file_input, model_a, model_b, api_key, chat_a, chat_b, question]
        chat_outputs = [chat_a, chat_b, question]
        send_btn.click(
            fn=chat_with_models,
            inputs=chat_inputs,
            outputs=chat_outputs,
        )
        # The question label tells users to press ENTER, so pressing ENTER in
        # the textbox must trigger the same handler as the Send button.
        question.submit(
            fn=chat_with_models,
            inputs=chat_inputs,
            outputs=chat_outputs,
        )

        def create_evaluate_fn(eval_type):
            # Factory so each button's handler captures its own label
            # (avoids the late-binding closure pitfall inside the loop).
            def evaluate_with_type(df, model_a, model_b):
                return evaluate(df, model_a, model_b, eval_type)
            return evaluate_with_type

        evaluation_buttons = [
            (a_better, "A is better"),
            (b_better, "B is better"),
            (tie, "Tie"),
            (both_bad, "Both are bad"),
        ]
        for btn, eval_type in evaluation_buttons:
            btn.click(
                fn=create_evaluate_fn(eval_type),
                inputs=[evaluation_df, model_a, model_b],
                outputs=[evaluation_df, evaluation_table, statistics_plot],
            )
    return demo
# When imported, only build the UI and expose it as ``demo``; when run as a
# script, also report where the evaluations CSV lives and launch the app.
if __name__ != "__main__":
    demo = create_demo()
else:
    print(f"CSV file is located at: {os.path.abspath(CSV_PATH)}")
    demo = create_demo()
    demo.launch(share=True)