GGINCoder committed on
Commit
e79740b
·
verified ·
1 Parent(s): b3b179e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +202 -0
app.py CHANGED
@@ -0,0 +1,202 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import gradio as gr
3
+ from langchain.document_loaders import PyMuPDFLoader, TextLoader
4
+ from langchain.text_splitter import RecursiveCharacterTextSplitter
5
+ from langchain.vectorstores import Chroma
6
+ from langchain.embeddings import HuggingFaceEmbeddings
7
+ from langchain.chains import RetrievalQA
8
+ from langchain.prompts import PromptTemplate
9
+ from langchain_groq import ChatGroq
10
+ import chardet
11
+ import pandas as pd
12
+ import plotly.graph_objs as go
13
+
14
# --- Runtime configuration -------------------------------------------------
# SECURITY FIX: a literal Groq API key was previously hard-coded here, which
# leaks the credential to anyone who can read the source.  The key must be
# supplied via the environment (e.g. a Hugging Face Spaces secret).
# `setdefault` keeps any key already exported and otherwise installs an empty
# placeholder so the module still imports; Groq calls will fail with a clear
# authentication error until a real key is configured.
os.environ.setdefault("GROQ_API_KEY", "")

# Directory containing this script.  `__file__` is undefined when the code is
# pasted into a REPL/notebook, so fall back to the current working directory.
try:
    SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))
except NameError:
    SCRIPT_DIR = os.getcwd()

# Persistent log of A/B evaluation verdicts, kept next to the script.
CSV_PATH = os.path.join(SCRIPT_DIR, "evaluations.csv")
22
+
23
def detect_encoding(file_path):
    """Guess the text encoding of *file_path* by sniffing its raw bytes."""
    with open(file_path, 'rb') as fh:
        payload = fh.read()
    guess = chardet.detect(payload)
    return guess['encoding']
27
+
28
def setup(file_path):
    """Load a PDF/TXT document, chunk it, embed it, and build a retriever.

    Parameters
    ----------
    file_path : str
        Path to a ``.pdf`` or ``.txt`` file.

    Returns
    -------
    tuple
        ``(retriever, chain_type_kwargs)`` suitable for
        ``RetrievalQA.from_chain_type``.

    Raises
    ------
    ValueError
        If the file extension is neither ``.pdf`` nor ``.txt``.
    """
    _, extension = os.path.splitext(file_path)
    if extension.lower() == '.pdf':
        loader = PyMuPDFLoader(file_path)
    elif extension.lower() == '.txt':
        # TXT uploads arrive in arbitrary encodings; sniff before decoding.
        encoding = detect_encoding(file_path)
        loader = TextLoader(file_path, encoding=encoding)
    else:
        raise ValueError("Unsupported file type. Please upload a PDF or TXT file.")

    document_data = loader.load()

    # Overlapping chunks so answers spanning a chunk boundary stay retrievable.
    text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200)
    all_splits = text_splitter.split_documents(document_data)

    model_name = "sentence-transformers/all-MiniLM-L6-v2"
    model_kwargs = {'device': 'cpu'}
    embedding = HuggingFaceEmbeddings(model_name=model_name, model_kwargs=model_kwargs)

    # BUG FIX: the previous version persisted every upload into the same
    # on-disk Chroma directory ('db'), so chunks from earlier documents leaked
    # into retrieval for later ones.  Build an ephemeral in-memory store
    # instead so each uploaded file gets a fresh, isolated index.
    vectordb = Chroma.from_documents(documents=all_splits, embedding=embedding)

    sys_prompt = """
You can only answer questions that are relevant to the content of the document.
If the question that user asks is not relevant to the content of the document, you should say "I don't know".
Do not provide any information that is not in the document.
Do not provide any information by yourself.
You should use sentences from the document to answer the questions.
"""

    instruction = """CONTEXT:\n\n {context}\n\nQuestion: {question}"""

    # Llama-2 style [INST]/<<SYS>> wrapping of the system prompt + question.
    prompt_template = f"[INST] <<SYS>>\n{sys_prompt}\n<</SYS>>\n\n{instruction}[/INST]"

    llama_prompt = PromptTemplate(
        template=prompt_template, input_variables=["context", "question"]
    )

    chain_type_kwargs = {"prompt": llama_prompt}
    retriever = vectordb.as_retriever()

    return retriever, chain_type_kwargs
70
+
71
def setup_qa(file_path, model_name):
    """Build a RetrievalQA chain over *file_path* backed by the given Groq model."""
    retriever, chain_type_kwargs = setup(file_path)
    groq_llm = ChatGroq(
        groq_api_key=os.environ["GROQ_API_KEY"],
        model_name=model_name,
    )
    return RetrievalQA.from_chain_type(
        llm=groq_llm,
        chain_type="stuff",
        retriever=retriever,
        chain_type_kwargs=chain_type_kwargs,
        verbose=True,
    )
85
+
86
def chat_with_models(file, model_a, model_b, history_a, history_b, question):
    """Ask *question* of both selected models over the uploaded document.

    Returns the two updated chat histories (lists of (user, bot) tuples).
    Validation failures and runtime errors are surfaced as chat messages
    rather than raised.
    """
    def broadcast(message):
        # Show the same notice in both transcripts without mutating them.
        return history_a + [(message, None)], history_b + [(message, None)]

    if file is None:
        return broadcast("請上傳文件。")

    file_path = file.name
    suffix = os.path.splitext(file_path)[1].lower()
    if suffix not in ('.pdf', '.txt'):
        return broadcast("只能上傳PDF或TXT檔案,不接受其他的格式。")

    try:
        # Build both chains first, then query each in turn.
        chain_a = setup_qa(file_path, model_a)
        chain_b = setup_qa(file_path, model_b)

        answer_a = chain_a.invoke(question)["result"]
        answer_b = chain_b.invoke(question)["result"]

        history_a.append((question, answer_a))
        history_b.append((question, answer_b))
        return history_a, history_b
    except Exception as e:
        return broadcast(f"遇到錯誤:{str(e)}")
111
+
112
def load_or_create_df():
    """Return the evaluation log from disk, or an empty frame if none exists."""
    if not os.path.exists(CSV_PATH):
        return pd.DataFrame(columns=['Model A', 'Model B', 'Evaluation', 'Count'])
    return pd.read_csv(CSV_PATH)
117
+
118
def record_evaluation(df, model_a, model_b, evaluation):
    """Append one verdict row to *df*, persist the log to CSV, and return it."""
    row = {
        'Model A': model_a,
        'Model B': model_b,
        'Evaluation': evaluation,
        'Count': 1,
    }
    updated_df = pd.concat([df, pd.DataFrame([row])], ignore_index=True)
    updated_df.to_csv(CSV_PATH, index=False)
    return updated_df
128
+
129
def update_statistics(df):
    """Aggregate verdict counts per model pairing and render a grouped bar chart."""
    totals = (
        df.groupby(['Model A', 'Model B', 'Evaluation'])['Count']
        .sum()
        .reset_index()
    )

    bars = []
    for _, record in totals.iterrows():
        pairing = f"{record['Model A']} vs {record['Model B']}"
        bars.append(go.Bar(name=record['Evaluation'], x=[pairing], y=[record['Count']]))

    fig = go.Figure(data=bars)
    fig.update_layout(barmode='group', title='Model Comparison Statistics')
    return fig
141
+
142
def evaluate(df, model_a, model_b, evaluation):
    """Record one verdict and return (state_df, table_df, statistics_figure)."""
    new_df = record_evaluation(df, model_a, model_b, evaluation)
    return new_df, new_df, update_statistics(new_df)
146
+
147
# Groq-hosted chat models offered for head-to-head comparison in the UI.
models = ["llama3-70b-8192", "mixtral-8x7b-32768", "llama3-8b-8192", "gemma-7b-it"]
148
+
149
def create_demo():
    """Build the Gradio Blocks UI for side-by-side model comparison.

    Layout: two model dropdowns, a file upload, paired chatbots, a prompt box,
    a Send button, and four verdict buttons whose clicks append rows to the
    evaluation log and refresh the table/plot.  Returns the (unlaunched)
    gr.Blocks app.
    """
    with gr.Blocks(theme=gr.themes.Soft()) as demo:
        gr.Markdown("# Choose two models to compare")

        with gr.Row():
            model_a = gr.Dropdown(choices=models, label="Model A", value=models[0])
            model_b = gr.Dropdown(choices=models, label="Model B", value=models[-1])

        file_input = gr.File(label="Upload PDF or TXT file")

        with gr.Row():
            chat_a = gr.Chatbot(label="Model A")
            chat_b = gr.Chatbot(label="Model B")

        question = gr.Textbox(label="Enter your prompt and press ENTER")

        send_btn = gr.Button("Send")

        # One verdict button per possible outcome of the comparison.
        with gr.Row():
            a_better = gr.Button("A is better")
            b_better = gr.Button("B is better")
            tie = gr.Button("Tie")
            both_bad = gr.Button("Both are bad")

        # NOTE(review): load_or_create_df() runs once at build time, so the
        # initial session state is a snapshot of the CSV at startup.
        evaluation_df = gr.State(load_or_create_df())
        evaluation_table = gr.Dataframe(label="Evaluation Records")
        statistics_plot = gr.Plot(label="Evaluation Statistics")

        # Send: run the question against both models and update both chats.
        send_btn.click(
            fn=chat_with_models,
            inputs=[file_input, model_a, model_b, chat_a, chat_b, question],
            outputs=[chat_a, chat_b],
        )

        # Factory avoids the classic late-binding-closure bug: each button
        # gets a callback with its own eval_type baked in.
        def create_evaluate_fn(eval_type):
            def evaluate_with_type(df, model_a, model_b):
                return evaluate(df, model_a, model_b, eval_type)
            return evaluate_with_type

        for btn, eval_type in [(a_better, "A is better"), (b_better, "B is better"), (tie, "Tie"), (both_bad, "Both are bad")]:
            btn.click(
                fn=create_evaluate_fn(eval_type),
                inputs=[evaluation_df, model_a, model_b],
                outputs=[evaluation_df, evaluation_table, statistics_plot],
            )

    return demo
196
+
197
# Entry point: print where evaluations are stored, then serve the app.
# When imported by a hosting runner instead, still build `demo` at module
# level so the host can launch it itself.
if __name__ == "__main__":
    print(f"CSV file is located at: {os.path.abspath(CSV_PATH)}")
    demo = create_demo()
    # NOTE(review): share=True opens a public tunnel — confirm that is intended.
    demo.launch(share=True)
else:
    demo = create_demo()