"""Side-by-side comparison of two Groq-hosted LLMs answering RAG questions
over an uploaded PDF/TXT document, with persisted human evaluations."""

import os

import chardet
import gradio as gr
import pandas as pd
import plotly.graph_objs as go
from langchain.chains import RetrievalQA
from langchain.document_loaders import PyMuPDFLoader, TextLoader
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.prompts import PromptTemplate
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.vectorstores import Chroma
from langchain_groq import ChatGroq

# Resolve the directory holding this script; __file__ is undefined in some
# interactive/hosted contexts, so fall back to the current working directory.
try:
    SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))
except NameError:
    SCRIPT_DIR = os.getcwd()

CSV_PATH = os.path.join(SCRIPT_DIR, "evaluations.csv")


def detect_encoding(file_path):
    """Guess the text encoding of *file_path* with chardet.

    Only the first 64 KiB is sampled -- enough for a reliable guess without
    loading a potentially large file into memory.  chardet returns ``None``
    for empty/ambiguous input, in which case we fall back to UTF-8 so the
    downstream ``TextLoader`` always receives a usable encoding name.
    """
    with open(file_path, 'rb') as file:
        raw_data = file.read(65536)
    return chardet.detect(raw_data)['encoding'] or 'utf-8'


def setup(file_path):
    """Load a PDF/TXT document, index it, and build the RAG prompt.

    Args:
        file_path: Path to a ``.pdf`` or ``.txt`` file.

    Returns:
        ``(retriever, chain_type_kwargs)`` -- a Chroma retriever over the
        document's chunks and the kwargs (prompt) for ``RetrievalQA``.

    Raises:
        ValueError: if the file extension is neither ``.pdf`` nor ``.txt``.
    """
    _, extension = os.path.splitext(file_path)
    if extension.lower() == '.pdf':
        loader = PyMuPDFLoader(file_path)
    elif extension.lower() == '.txt':
        loader = TextLoader(file_path, encoding=detect_encoding(file_path))
    else:
        raise ValueError("Unsupported file type. Please upload a PDF or TXT file.")

    document_data = loader.load()
    text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200)
    all_splits = text_splitter.split_documents(document_data)

    embedding = HuggingFaceEmbeddings(
        model_name="sentence-transformers/all-MiniLM-L6-v2",
        model_kwargs={'device': 'cpu'},
    )
    # NOTE: deliberately no persist_directory.  The previous code persisted
    # every upload into the same 'db' collection, so chunks from earlier
    # documents leaked into retrieval results for later ones.  An ephemeral
    # store scopes the index to the current upload.
    vectordb = Chroma.from_documents(documents=all_splits, embedding=embedding)

    sys_prompt = """ You can only answer questions that are relevant to the content of the document. If the question that user asks is not relevant to the content of the document, you should say "I don't know". Do not provide any information that is not in the document. Do not provide any information by yourself. You should use sentences from the document to answer the questions. 
"""
    instruction = """CONTEXT:\n\n {context}\n\nQuestion: {question}"""
    # Llama-2 chat format.  The original template used bare '<>' where the
    # system-prompt delimiters <<SYS>> / <</SYS>> belong -- fixed here.
    prompt_template = f"[INST] <<SYS>>\n{sys_prompt}\n<</SYS>>\n\n{instruction}[/INST]"
    llama_prompt = PromptTemplate(
        template=prompt_template,
        input_variables=["context", "question"],
    )
    return vectordb.as_retriever(), {"prompt": llama_prompt}


def _build_qa(retriever, chain_type_kwargs, model_name, api_key):
    """Create a RetrievalQA chain for one Groq model over a shared retriever."""
    llm = ChatGroq(
        groq_api_key=api_key,
        model_name=model_name,
    )
    return RetrievalQA.from_chain_type(
        llm=llm,
        chain_type="stuff",
        retriever=retriever,
        chain_type_kwargs=chain_type_kwargs,
        verbose=True,
    )


def setup_qa(file_path, model_name, api_key):
    """Build a complete QA pipeline (document index + chain) for one model."""
    retriever, chain_type_kwargs = setup(file_path)
    return _build_qa(retriever, chain_type_kwargs, model_name, api_key)


def chat_with_models(file, model_a, model_b, api_key, history_a, history_b, question):
    """Ask *question* to both selected models over the uploaded document.

    Returns updated chat histories for model A and model B plus an empty
    string to clear the question textbox.  Validation failures and runtime
    errors are surfaced as chat messages (not raised) so the UI stays usable.
    """
    if file is None:
        return history_a + [("請上傳文件。", None)], history_b + [("請上傳文件。", None)], ""
    if not api_key:
        return history_a + [("請輸入 API 金鑰。", None)], history_b + [("請輸入 API 金鑰。", None)], ""

    file_path = file.name
    _, extension = os.path.splitext(file_path)
    if extension.lower() not in ['.pdf', '.txt']:
        error_message = "只能上傳PDF或TXT檔案,不接受其他的格式。"
        return history_a + [(error_message, None)], history_b + [(error_message, None)], ""

    try:
        # Index the document ONCE and share the retriever between both
        # chains -- the original re-loaded, re-split and re-embedded the
        # entire document separately for each model on every question.
        retriever, chain_type_kwargs = setup(file_path)
        qa_a = _build_qa(retriever, chain_type_kwargs, model_a, api_key)
        qa_b = _build_qa(retriever, chain_type_kwargs, model_b, api_key)
        response_a = qa_a.invoke(question)
        response_b = qa_b.invoke(question)
        history_a.append((question, response_a["result"]))
        history_b.append((question, response_b["result"]))
        return history_a, history_b, ""
    except Exception as e:
        error_message = f"遇到錯誤:{str(e)}"
        return history_a + [(error_message, None)], history_b + [(error_message, None)], ""


def load_or_create_df():
    """Load the persisted evaluation records, or start an empty table."""
    if os.path.exists(CSV_PATH):
        return pd.read_csv(CSV_PATH)
    return pd.DataFrame(columns=['Model A', 'Model B', 'Evaluation', 'Count'])


def record_evaluation(df, model_a, model_b, evaluation):
    """Append one evaluation vote and persist the full table to CSV."""
    new_row = pd.DataFrame({
        'Model A': [model_a],
        'Model B': [model_b],
        'Evaluation': [evaluation],
        'Count': [1],
    })
    updated_df = pd.concat([df, new_row], ignore_index=True)
    updated_df.to_csv(CSV_PATH, index=False)
    return updated_df


def update_statistics(df):
    """Build a grouped bar chart of evaluation counts per model pairing.

    One trace per verdict (so each verdict appears exactly once in the
    legend -- the previous per-row traces produced duplicate legend
    entries), with one bar per "A vs B" pairing inside each trace.
    """
    stats = df.groupby(['Model A', 'Model B', 'Evaluation'])['Count'].sum().reset_index()
    stats['Pair'] = stats['Model A'] + " vs " + stats['Model B']
    fig = go.Figure(data=[
        go.Bar(name=verdict, x=group['Pair'].tolist(), y=group['Count'].tolist())
        for verdict, group in stats.groupby('Evaluation')
    ])
    fig.update_layout(barmode='group', title='Model Comparison Statistics')
    return fig


def evaluate(df, model_a, model_b, evaluation):
    """Record one vote, then refresh the state, the table, and the plot."""
    updated_df = record_evaluation(df, model_a, model_b, evaluation)
    fig = update_statistics(updated_df)
    return updated_df, updated_df, fig


# Groq-hosted models offered for comparison.
models = ["llama3-70b-8192", "mixtral-8x7b-32768", "llama3-8b-8192", "gemma-7b-it"]


def create_demo():
    """Assemble the Gradio Blocks UI for side-by-side model comparison."""
    with gr.Blocks(theme=gr.themes.Soft()) as demo:
        gr.Markdown("# Choose two models to compare")
        with gr.Row():
            model_a = gr.Dropdown(choices=models, label="Model A", value=models[0])
            model_b = gr.Dropdown(choices=models, label="Model B", value=models[-1])
        api_key = gr.Textbox(label="Enter your Groq API Key", type="password")
        file_input = gr.File(label="Upload PDF or TXT file")
        with gr.Row():
            chat_a = gr.Chatbot(label="Model A")
            chat_b = gr.Chatbot(label="Model B")
        question = gr.Textbox(label="Enter your prompt and press ENTER")
        send_btn = gr.Button("Send")
        with gr.Row():
            a_better = gr.Button("A is better")
            b_better = gr.Button("B is better")
            tie = gr.Button("Tie")
            both_bad = gr.Button("Both are bad")
        evaluation_df = gr.State(load_or_create_df())
        evaluation_table = gr.Dataframe(label="Evaluation Records")
        statistics_plot = gr.Plot(label="Evaluation Statistics")

        send_btn.click(
            fn=chat_with_models,
            inputs=[file_input, model_a, model_b, api_key, chat_a, chat_b, question],
            outputs=[chat_a, chat_b, question],
        )

        def create_evaluate_fn(eval_type):
            # Bind eval_type per button: a bare closure created inside the
            # loop below would late-bind and make every button report the
            # label of the last iteration.
            def evaluate_with_type(df, model_a, model_b):
                return evaluate(df, model_a, model_b, eval_type)
            return evaluate_with_type

        for btn, eval_type in [(a_better, "A is better"), (b_better, "B is better"),
                               (tie, "Tie"), (both_bad, "Both are bad")]:
            btn.click(
                fn=create_evaluate_fn(eval_type),
                inputs=[evaluation_df, model_a, model_b],
                outputs=[evaluation_df, evaluation_table, statistics_plot],
            )
    return demo


if __name__ == "__main__":
    print(f"CSV file is located at: {os.path.abspath(CSV_PATH)}")
    demo = create_demo()
    demo.launch(share=True)
else:
    # Imported by a hosting platform: build the demo but let the host launch it.
    demo = create_demo()