| import os | |
| import gradio as gr | |
| from langchain.document_loaders import PyMuPDFLoader, TextLoader | |
| from langchain.text_splitter import RecursiveCharacterTextSplitter | |
| from langchain.vectorstores import Chroma | |
| from langchain.embeddings import HuggingFaceEmbeddings | |
| from langchain.chains import RetrievalQA | |
| from langchain.prompts import PromptTemplate | |
| from langchain_groq import ChatGroq | |
| import chardet | |
| import pandas as pd | |
| import plotly.graph_objs as go | |
# Resolve the directory this script lives in; fall back to the current
# working directory when __file__ is undefined (e.g. interactive sessions).
if "__file__" in globals():
    SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))
else:
    SCRIPT_DIR = os.getcwd()

# Evaluation records are persisted next to the script.
CSV_PATH = os.path.join(SCRIPT_DIR, "evaluations.csv")
def detect_encoding(file_path):
    """Guess the text encoding of the file at *file_path* using chardet.

    Reads at most 1 MiB — plenty for a reliable guess, without pulling a
    large document into memory (the previous version read the whole file).

    Returns:
        The detected encoding name, falling back to ``"utf-8"`` when
        chardet cannot decide (``detect`` returns ``None`` for e.g. an
        empty file, which would otherwise crash the TextLoader).
    """
    with open(file_path, 'rb') as file:
        raw_data = file.read(1024 * 1024)
    # chardet.detect may report encoding=None on empty/ambiguous input.
    return chardet.detect(raw_data)['encoding'] or 'utf-8'
def setup(file_path):
    """Load a PDF/TXT document and build a retriever plus prompt kwargs.

    Args:
        file_path: Path to a ``.pdf`` or ``.txt`` file.

    Returns:
        ``(retriever, chain_type_kwargs)`` ready to be passed to
        ``RetrievalQA.from_chain_type``.

    Raises:
        ValueError: If the file extension is neither ``.pdf`` nor ``.txt``.
    """
    _, extension = os.path.splitext(file_path)
    if extension.lower() == '.pdf':
        loader = PyMuPDFLoader(file_path)
    elif extension.lower() == '.txt':
        # TXT uploads arrive in arbitrary encodings; sniff before loading.
        encoding = detect_encoding(file_path)
        loader = TextLoader(file_path, encoding=encoding)
    else:
        raise ValueError("Unsupported file type. Please upload a PDF or TXT file.")
    document_data = loader.load()

    text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200)
    all_splits = text_splitter.split_documents(document_data)

    model_name = "sentence-transformers/all-MiniLM-L6-v2"
    model_kwargs = {'device': 'cpu'}
    embedding = HuggingFaceEmbeddings(model_name=model_name, model_kwargs=model_kwargs)
    # BUG FIX: the store used to be persisted to a shared 'db' directory, so
    # chunks from previously uploaded documents leaked into the retrieval
    # results of later sessions. An in-memory store scopes the index to this
    # upload only.
    vectordb = Chroma.from_documents(documents=all_splits, embedding=embedding)

    sys_prompt = """
You can only answer questions that are relevant to the content of the document.
If the question that user asks is not relevant to the content of the document, you should say "I don't know".
Do not provide any information that is not in the document.
Do not provide any information by yourself.
You should use sentences from the document to answer the questions.
"""
    instruction = """CONTEXT:\n\n {context}\n\nQuestion: {question}"""
    # Llama-style [INST]/<<SYS>> wrapper; {context}/{question} stay as
    # PromptTemplate placeholders because the f-string only interpolates
    # sys_prompt and instruction.
    prompt_template = f"[INST] <<SYS>>\n{sys_prompt}\n<</SYS>>\n\n{instruction}[/INST]"
    llama_prompt = PromptTemplate(
        template=prompt_template, input_variables=["context", "question"]
    )
    chain_type_kwargs = {"prompt": llama_prompt}
    retriever = vectordb.as_retriever()
    return retriever, chain_type_kwargs
def setup_qa(file_path, model_name, api_key):
    """Build a RetrievalQA chain over *file_path* backed by a Groq-hosted LLM.

    Args:
        file_path: Document to index (PDF or TXT).
        model_name: Groq model identifier.
        api_key: Groq API key supplied by the user.

    Returns:
        A ready-to-invoke RetrievalQA chain.
    """
    retriever, chain_type_kwargs = setup(file_path)
    llm = ChatGroq(groq_api_key=api_key, model_name=model_name)
    return RetrievalQA.from_chain_type(
        llm=llm,
        chain_type="stuff",
        retriever=retriever,
        chain_type_kwargs=chain_type_kwargs,
        verbose=True,
    )
def chat_with_models(file, model_a, model_b, api_key, history_a, history_b, question):
    """Ask *question* about the uploaded file with both selected models.

    Args:
        file: Gradio file object (or None when nothing is uploaded).
        model_a, model_b: Groq model identifiers for the two panes.
        api_key: User-supplied Groq API key.
        history_a, history_b: Chatbot histories as lists of (user, bot) tuples.
        question: The user's prompt.

    Returns:
        ``(history_a, history_b, "")`` — updated histories plus an empty
        string that clears the question textbox. Errors are surfaced as
        chat messages (in Chinese, matching the UI) rather than raised.
    """
    if file is None:
        return history_a + [("請上傳文件。", None)], history_b + [("請上傳文件。", None)], ""
    if not api_key:
        return history_a + [("請輸入 API 金鑰。", None)], history_b + [("請輸入 API 金鑰。", None)], ""
    file_path = file.name
    _, extension = os.path.splitext(file_path)
    if extension.lower() not in ['.pdf', '.txt']:
        error_message = "只能上傳PDF或TXT檔案,不接受其他的格式。"
        return history_a + [(error_message, None)], history_b + [(error_message, None)], ""
    try:
        # PERF FIX: build the vector store once and share the retriever
        # between both models. Calling setup_qa per model (as before)
        # loaded, split, and embedded the same document twice per question.
        retriever, chain_type_kwargs = setup(file_path)
        responses = []
        for model in (model_a, model_b):
            llm = ChatGroq(groq_api_key=api_key, model_name=model)
            qa = RetrievalQA.from_chain_type(
                llm=llm,
                chain_type="stuff",
                retriever=retriever,
                chain_type_kwargs=chain_type_kwargs,
                verbose=True,
            )
            responses.append(qa.invoke(question))
        history_a.append((question, responses[0]["result"]))
        history_b.append((question, responses[1]["result"]))
        return history_a, history_b, ""
    except Exception as e:
        error_message = f"遇到錯誤:{str(e)}"
        return history_a + [(error_message, None)], history_b + [(error_message, None)], ""
def load_or_create_df():
    """Return the persisted evaluation table, or a fresh empty one.

    Reads CSV_PATH when it exists; otherwise returns an empty DataFrame
    with the expected columns.
    """
    if not os.path.exists(CSV_PATH):
        return pd.DataFrame(columns=['Model A', 'Model B', 'Evaluation', 'Count'])
    return pd.read_csv(CSV_PATH)
def record_evaluation(df, model_a, model_b, evaluation):
    """Append one evaluation verdict to *df* and persist the result.

    Args:
        df: Current evaluation DataFrame.
        model_a, model_b: The two models that were compared.
        evaluation: Verdict label (e.g. "A is better").

    Returns:
        The extended DataFrame (also written to CSV_PATH).
    """
    row = {
        'Model A': model_a,
        'Model B': model_b,
        'Evaluation': evaluation,
        'Count': 1,
    }
    updated_df = pd.concat([df, pd.DataFrame([row])], ignore_index=True)
    updated_df.to_csv(CSV_PATH, index=False)
    return updated_df
def update_statistics(df):
    """Aggregate evaluation counts and render them as a grouped bar chart.

    Groups by (Model A, Model B, Evaluation), sums the counts, and emits
    one bar per group keyed on the "A vs B" pairing.
    """
    stats = df.groupby(['Model A', 'Model B', 'Evaluation'])['Count'].sum().reset_index()
    bars = []
    for _, row in stats.iterrows():
        pairing = f"{row['Model A']} vs {row['Model B']}"
        bars.append(go.Bar(name=row['Evaluation'], x=[pairing], y=[row['Count']]))
    fig = go.Figure(data=bars)
    fig.update_layout(barmode='group', title='Model Comparison Statistics')
    return fig
def evaluate(df, model_a, model_b, evaluation):
    """Record one verdict and return (state_df, table_df, statistics_figure).

    The DataFrame is returned twice because it feeds both the gr.State and
    the visible gr.Dataframe outputs.
    """
    new_df = record_evaluation(df, model_a, model_b, evaluation)
    return new_df, new_df, update_statistics(new_df)
# Groq-hosted models offered in the two comparison dropdowns.
models = ["llama3-70b-8192", "mixtral-8x7b-32768", "llama3-8b-8192", "gemma-7b-it"]
def create_demo():
    """Build the Gradio Blocks UI for side-by-side model comparison.

    Returns:
        The (unlaunched) gr.Blocks app.
    """
    with gr.Blocks(theme=gr.themes.Soft()) as demo:
        # BUG FIX: the heading claimed "three models" but the UI offers
        # exactly two dropdowns.
        gr.Markdown("# Choose two models to compare(With KnowledgeGraphRAG)")
        with gr.Row():
            model_a = gr.Dropdown(choices=models, label="Model A", value=models[0])
            model_b = gr.Dropdown(choices=models, label="Model B", value=models[-1])
        api_key = gr.Textbox(label="Enter your Groq API Key", type="password")
        file_input = gr.File(label="Upload PDF or TXT file")
        with gr.Row():
            chat_a = gr.Chatbot(label="Model A")
            chat_b = gr.Chatbot(label="Model B")
        question = gr.Textbox(label="Enter your prompt and press ENTER")
        send_btn = gr.Button("Send")
        with gr.Row():
            a_better = gr.Button("A is better")
            b_better = gr.Button("B is better")
            tie = gr.Button("Tie")
            both_bad = gr.Button("Both are bad")
        # Per-session evaluation records, seeded from the on-disk CSV at
        # app-build time.
        evaluation_df = gr.State(load_or_create_df())
        evaluation_table = gr.Dataframe(label="Evaluation Records")
        statistics_plot = gr.Plot(label="Evaluation Statistics")

        send_btn.click(
            fn=chat_with_models,
            inputs=[file_input, model_a, model_b, api_key, chat_a, chat_b, question],
            outputs=[chat_a, chat_b, question],
        )

        def create_evaluate_fn(eval_type):
            # Factory binds eval_type at definition time so each button
            # records its own verdict (avoids the late-binding-closure
            # pitfall in the loop below).
            def evaluate_with_type(df, model_a, model_b):
                return evaluate(df, model_a, model_b, eval_type)
            return evaluate_with_type

        for btn, eval_type in [(a_better, "A is better"), (b_better, "B is better"),
                               (tie, "Tie"), (both_bad, "Both are bad")]:
            btn.click(
                fn=create_evaluate_fn(eval_type),
                inputs=[evaluation_df, model_a, model_b],
                outputs=[evaluation_df, evaluation_table, statistics_plot],
            )
    return demo
if __name__ == "__main__":
    # Run as a script: report where evaluations are stored and launch with a
    # public share link.
    print(f"CSV file is located at: {os.path.abspath(CSV_PATH)}")
    demo = create_demo()
    demo.launch(share=True)
else:
    # Imported as a module (presumably by a hosting platform such as Hugging
    # Face Spaces, which looks for a module-level `demo` — TODO confirm):
    # build the app but do not launch it.
    demo = create_demo()