# Gradio app: compare two Groq chat models answering RAG questions over an
# uploaded PDF/TXT document, and record pairwise human evaluations to CSV.
import os
import gradio as gr
from langchain.document_loaders import PyMuPDFLoader, TextLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.vectorstores import Chroma
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.chains import RetrievalQA
from langchain.prompts import PromptTemplate
from langchain_groq import ChatGroq
import chardet
import pandas as pd
import plotly.graph_objs as go
# Resolve the directory containing this script; fall back to the current
# working directory when __file__ is undefined (e.g. interactive sessions).
try:
    SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))
except NameError:
    SCRIPT_DIR = os.getcwd()
# Evaluation records are persisted next to the script.
CSV_PATH = os.path.join(SCRIPT_DIR, "evaluations.csv")
def detect_encoding(file_path):
    """Guess the text encoding of *file_path* with chardet.

    Args:
        file_path: Path to the file to sniff.

    Returns:
        The detected encoding name, or 'utf-8' when chardet cannot make a
        guess (it reports ``{'encoding': None}`` for e.g. empty files),
        so the downstream ``TextLoader(encoding=...)`` call never receives
        ``None``.
    """
    with open(file_path, 'rb') as file:
        raw_data = file.read()
    return chardet.detect(raw_data)['encoding'] or 'utf-8'
def setup(file_path):
    """Load a document, index its chunks, and build the QA prompt.

    Args:
        file_path: Path to a ``.pdf`` or ``.txt`` file.

    Returns:
        Tuple ``(retriever, chain_type_kwargs)``: a Chroma retriever over
        the document's chunks and the prompt kwargs for ``RetrievalQA``.

    Raises:
        ValueError: If the file is neither PDF nor TXT.
    """
    _, extension = os.path.splitext(file_path)
    extension = extension.lower()
    if extension == '.pdf':
        loader = PyMuPDFLoader(file_path)
    elif extension == '.txt':
        # TXT files may arrive in arbitrary encodings; sniff before loading.
        loader = TextLoader(file_path, encoding=detect_encoding(file_path))
    else:
        raise ValueError("Unsupported file type. Please upload a PDF or TXT file.")
    document_data = loader.load()

    # Overlapping 1000-char chunks keep sentence context across boundaries.
    text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200)
    all_splits = text_splitter.split_documents(document_data)

    embedding = HuggingFaceEmbeddings(
        model_name="sentence-transformers/all-MiniLM-L6-v2",
        model_kwargs={'device': 'cpu'},
    )
    # BUG FIX: the original persisted every upload into the same 'db'
    # directory, so chunks from previously uploaded files accumulated in the
    # collection and leaked into answers about later files. Build an
    # in-memory store per call instead (no persist_directory).
    vectordb = Chroma.from_documents(documents=all_splits, embedding=embedding)

    # Llama-2-style [INST]/<<SYS>> prompt restricting answers to the document.
    sys_prompt = """
You can only answer questions that are relevant to the content of the document.
If the question that user asks is not relevant to the content of the document, you should say "I don't know".
Do not provide any information that is not in the document.
Do not provide any information by yourself.
You should use sentences from the document to answer the questions.
"""
    instruction = """CONTEXT:\n\n {context}\n\nQuestion: {question}"""
    prompt_template = f"[INST] <<SYS>>\n{sys_prompt}\n<</SYS>>\n\n{instruction}[/INST]"
    llama_prompt = PromptTemplate(
        template=prompt_template, input_variables=["context", "question"]
    )
    chain_type_kwargs = {"prompt": llama_prompt}
    retriever = vectordb.as_retriever()
    return retriever, chain_type_kwargs
def setup_qa(file_path, model_name, api_key):
    """Create a RetrievalQA chain over *file_path* backed by a Groq model.

    Args:
        file_path: Path to the uploaded PDF/TXT document.
        model_name: Groq-hosted model identifier to query.
        api_key: Groq API key supplied by the user.

    Returns:
        A ``RetrievalQA`` chain ready to ``invoke`` with a question.
    """
    retriever, chain_type_kwargs = setup(file_path)
    groq_llm = ChatGroq(groq_api_key=api_key, model_name=model_name)
    return RetrievalQA.from_chain_type(
        llm=groq_llm,
        chain_type="stuff",
        retriever=retriever,
        chain_type_kwargs=chain_type_kwargs,
        verbose=True,
    )
def chat_with_models(file, model_a, model_b, api_key, history_a, history_b, question):
    """Ask *question* about the uploaded file to two models side by side.

    Args:
        file: Gradio file object (has ``.name``) or None.
        model_a, model_b: Groq model identifiers for the two chatbots.
        api_key: Groq API key; required.
        history_a, history_b: Chatbot histories as lists of (user, bot) pairs.
        question: The user's prompt.

    Returns:
        Tuple ``(history_a, history_b, "")`` — updated histories and an
        empty string that clears the prompt textbox. Validation failures
        append a Chinese error message to both histories instead of raising.
    """
    if file is None:
        return history_a + [("請上傳文件。", None)], history_b + [("請上傳文件。", None)], ""
    if not api_key:
        return history_a + [("請輸入 API 金鑰。", None)], history_b + [("請輸入 API 金鑰。", None)], ""
    file_path = file.name
    _, extension = os.path.splitext(file_path)
    if extension.lower() not in ['.pdf', '.txt']:
        error_message = "只能上傳PDF或TXT檔案,不接受其他的格式。"
        return history_a + [(error_message, None)], history_b + [(error_message, None)], ""
    # BUG FIX: a blank prompt previously still built two full RAG pipelines
    # (embedding + vector store) and made two paid API calls. Skip it.
    if not question or not question.strip():
        return history_a, history_b, ""
    try:
        qa_a = setup_qa(file_path, model_a, api_key)
        qa_b = setup_qa(file_path, model_b, api_key)
        response_a = qa_a.invoke(question)
        response_b = qa_b.invoke(question)
        history_a.append((question, response_a["result"]))
        history_b.append((question, response_b["result"]))
        return history_a, history_b, ""
    except Exception as e:
        # Surface pipeline/API failures in both chat windows rather than crash.
        error_message = f"遇到錯誤:{str(e)}"
        return history_a + [(error_message, None)], history_b + [(error_message, None)], ""
def load_or_create_df():
    """Load persisted evaluation records from CSV_PATH, or start empty.

    Returns:
        A DataFrame with columns Model A / Model B / Evaluation / Count.
    """
    if not os.path.exists(CSV_PATH):
        return pd.DataFrame(columns=['Model A', 'Model B', 'Evaluation', 'Count'])
    return pd.read_csv(CSV_PATH)
def record_evaluation(df, model_a, model_b, evaluation):
    """Append one evaluation row to *df* and persist the result to CSV_PATH.

    Args:
        df: Existing evaluation DataFrame.
        model_a, model_b: The two model identifiers being compared.
        evaluation: Verdict label, e.g. "A is better" or "Tie".

    Returns:
        The DataFrame with the new row appended (also written to disk).
    """
    row = {
        'Model A': model_a,
        'Model B': model_b,
        'Evaluation': evaluation,
        'Count': 1,
    }
    updated_df = pd.concat([df, pd.DataFrame([row])], ignore_index=True)
    updated_df.to_csv(CSV_PATH, index=False)
    return updated_df
def update_statistics(df):
    """Aggregate evaluation counts per model pairing into a grouped bar chart.

    Args:
        df: Evaluation DataFrame (Model A / Model B / Evaluation / Count).

    Returns:
        A plotly Figure with one bar per (pairing, verdict) total.
    """
    totals = df.groupby(['Model A', 'Model B', 'Evaluation'])['Count'].sum().reset_index()
    bars = []
    for _, rec in totals.iterrows():
        pairing = f"{rec['Model A']} vs {rec['Model B']}"
        bars.append(go.Bar(name=rec['Evaluation'], x=[pairing], y=[rec['Count']]))
    fig = go.Figure(data=bars)
    fig.update_layout(barmode='group', title='Model Comparison Statistics')
    return fig
def evaluate(df, model_a, model_b, evaluation):
    """Record one verdict and return the Gradio outputs (state, table, plot)."""
    new_df = record_evaluation(df, model_a, model_b, evaluation)
    return new_df, new_df, update_statistics(new_df)
# Groq-hosted chat models offered in both comparison dropdowns.
models = ["llama3-70b-8192", "mixtral-8x7b-32768", "llama3-8b-8192", "gemma-7b-it"]
def create_demo():
    """Assemble the Gradio Blocks UI for side-by-side model comparison.

    Layout: two model dropdowns, an API-key field, a file upload, paired
    chatbots, a prompt box with a Send button, four evaluation buttons, and
    a persisted evaluation table plus a statistics plot.

    Returns:
        The configured ``gr.Blocks`` demo (not launched).
    """
    with gr.Blocks(theme=gr.themes.Soft()) as demo:
        # BUG FIX: the heading claimed "three models" but the UI offers
        # exactly two (Model A and Model B).
        # NOTE(review): the pipeline is plain vector RAG (Chroma); the
        # "KnowledgeGraphRAG" label looks inaccurate — confirm with authors.
        gr.Markdown("# Choose two models to compare(With KnowledgeGraphRAG)")
        with gr.Row():
            model_a = gr.Dropdown(choices=models, label="Model A", value=models[0])
            model_b = gr.Dropdown(choices=models, label="Model B", value=models[-1])
        api_key = gr.Textbox(label="Enter your Groq API Key", type="password")
        file_input = gr.File(label="Upload PDF or TXT file")
        with gr.Row():
            chat_a = gr.Chatbot(label="Model A")
            chat_b = gr.Chatbot(label="Model B")
        question = gr.Textbox(label="Enter your prompt and press ENTER")
        send_btn = gr.Button("Send")
        with gr.Row():
            a_better = gr.Button("A is better")
            b_better = gr.Button("B is better")
            tie = gr.Button("Tie")
            both_bad = gr.Button("Both are bad")
        # Evaluation records live in session state, seeded from the CSV.
        evaluation_df = gr.State(load_or_create_df())
        evaluation_table = gr.Dataframe(label="Evaluation Records")
        statistics_plot = gr.Plot(label="Evaluation Statistics")
        send_btn.click(
            fn=chat_with_models,
            inputs=[file_input, model_a, model_b, api_key, chat_a, chat_b, question],
            outputs=[chat_a, chat_b, question],
        )

        def create_evaluate_fn(eval_type):
            # Bind eval_type per button now to avoid the late-binding-closure
            # trap in the loop below.
            def evaluate_with_type(df, model_a, model_b):
                return evaluate(df, model_a, model_b, eval_type)
            return evaluate_with_type

        for btn, eval_type in [(a_better, "A is better"), (b_better, "B is better"), (tie, "Tie"), (both_bad, "Both are bad")]:
            btn.click(
                fn=create_evaluate_fn(eval_type),
                inputs=[evaluation_df, model_a, model_b],
                outputs=[evaluation_df, evaluation_table, statistics_plot],
            )
    return demo
if __name__ == "__main__":
    # Run as a script: report where evaluations are stored, then launch
    # with a public share link.
    print(f"CSV file is located at: {os.path.abspath(CSV_PATH)}")
    demo = create_demo()
    demo.launch(share=True)
else:
    # When imported (presumably by a hosting platform that launches the app
    # itself — TODO confirm), expose the Blocks app at module level.
    demo = create_demo()