GGINCoder commited on
Commit
47d053a
·
verified ·
1 Parent(s): dc2cdeb

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +24 -10
app.py CHANGED
@@ -10,13 +10,15 @@ from langchain_groq import ChatGroq
10
  import chardet
11
  import pandas as pd
12
  import plotly.graph_objs as go
 
13
 
14
- try:
15
- SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))
16
- except NameError:
17
- SCRIPT_DIR = os.getcwd()
18
 
19
- CSV_PATH = os.path.join(SCRIPT_DIR, "evaluations.csv")
 
 
 
20
 
21
  def detect_encoding(file_path):
22
  with open(file_path, 'rb') as file:
@@ -111,10 +113,13 @@ def chat_with_models(file, model_a, model_b, api_key, history_a, history_b, ques
111
  return history_a + [(error_message, None)], history_b + [(error_message, None)], ""
112
 
113
  def load_or_create_df():
114
- if os.path.exists(CSV_PATH):
115
- return pd.read_csv(CSV_PATH)
116
  else:
117
- return pd.DataFrame(columns=['Model A', 'Model B', 'Evaluation', 'Count'])
 
 
 
118
 
119
  def record_evaluation(df, model_a, model_b, evaluation):
120
  new_row = pd.DataFrame({
@@ -124,9 +129,15 @@ def record_evaluation(df, model_a, model_b, evaluation):
124
  'Count': [1]
125
  })
126
  updated_df = pd.concat([df, new_row], ignore_index=True)
127
- updated_df.to_csv(CSV_PATH, index=False)
 
 
128
  return updated_df
129
 
 
 
 
 
130
  def update_statistics(df):
131
  stats = df.groupby(['Model A', 'Model B', 'Evaluation'])['Count'].sum().reset_index()
132
 
@@ -148,6 +159,9 @@ def evaluate(df, model_a, model_b, evaluation):
148
  models = ["llama3-70b-8192", "mixtral-8x7b-32768", "llama3-8b-8192", "gemma-7b-it"]
149
 
150
  def create_demo():
 
 
 
151
  with gr.Blocks(theme=gr.themes.Soft()) as demo:
152
  gr.Markdown("# Choose two models to compare")
153
 
@@ -197,7 +211,7 @@ def create_demo():
197
  return demo
198
 
199
  if __name__ == "__main__":
200
- print(f"CSV file is located at: {os.path.abspath(CSV_PATH)}")
201
  demo = create_demo()
202
  demo.launch(share=True)
203
  else:
 
10
  import chardet
11
  import pandas as pd
12
  import plotly.graph_objs as go
13
+ import shutil
14
 
15
+ TMP_CSV_PATH = "/tmp/evaluations.csv"
16
+ PERSISTENT_CSV_PATH = "/app/files/evaluations.csv"
 
 
17
 
18
+ def ensure_directory_exists(file_path):
19
+ directory = os.path.dirname(file_path)
20
+ if not os.path.exists(directory):
21
+ os.makedirs(directory)
22
 
23
  def detect_encoding(file_path):
24
  with open(file_path, 'rb') as file:
 
113
  return history_a + [(error_message, None)], history_b + [(error_message, None)], ""
114
 
115
  def load_or_create_df():
116
+ if os.path.exists(PERSISTENT_CSV_PATH):
117
+ return pd.read_csv(PERSISTENT_CSV_PATH)
118
  else:
119
+ df = pd.DataFrame(columns=['Model A', 'Model B', 'Evaluation', 'Count'])
120
+ ensure_directory_exists(PERSISTENT_CSV_PATH)
121
+ df.to_csv(PERSISTENT_CSV_PATH, index=False)
122
+ return df
123
 
124
  def record_evaluation(df, model_a, model_b, evaluation):
125
  new_row = pd.DataFrame({
 
129
  'Count': [1]
130
  })
131
  updated_df = pd.concat([df, new_row], ignore_index=True)
132
+ ensure_directory_exists(TMP_CSV_PATH)
133
+ updated_df.to_csv(TMP_CSV_PATH, index=False)
134
+ copy_to_persistent_storage()
135
  return updated_df
136
 
137
+ def copy_to_persistent_storage():
138
+ ensure_directory_exists(PERSISTENT_CSV_PATH)
139
+ shutil.copy2(TMP_CSV_PATH, PERSISTENT_CSV_PATH)
140
+
141
  def update_statistics(df):
142
  stats = df.groupby(['Model A', 'Model B', 'Evaluation'])['Count'].sum().reset_index()
143
 
 
159
  models = ["llama3-70b-8192", "mixtral-8x7b-32768", "llama3-8b-8192", "gemma-7b-it"]
160
 
161
  def create_demo():
162
+ print(f"Temporary CSV file is located at: {TMP_CSV_PATH}")
163
+ print(f"Persistent CSV file is located at: {PERSISTENT_CSV_PATH}")
164
+
165
  with gr.Blocks(theme=gr.themes.Soft()) as demo:
166
  gr.Markdown("# Choose two models to compare")
167
 
 
211
  return demo
212
 
213
  if __name__ == "__main__":
214
+ print(f"CSV file is located at: {PERSISTENT_CSV_PATH}")
215
  demo = create_demo()
216
  demo.launch(share=True)
217
  else: