Spaces:
Sleeping
Sleeping
| import streamlit as st | |
| import pandas as pd | |
| import matplotlib.pyplot as plt | |
| import os | |
| from huggingface_hub import HfFileSystem | |
# Hugging Face Space repository id.
REPO_ID = "nsourlos/draco_streamlit"
# Access token read from the environment; None when HF_TOKEN is not set.
HF_TOKEN = os.environ.get("HF_TOKEN")
# Local CSV that the editor loads and writes back to.
data_path = 'data.csv'
def load_data(file):
    """Read a CSV file into a DataFrame indexed by its ``id`` column.

    Args:
        file: Path or file-like object accepted by ``pandas.read_csv``.

    Returns:
        A ``pandas.DataFrame`` whose index is the ``id`` column.
    """
    return pd.read_csv(file, index_col='id')
def save_data(df, filename):
    """Write *df* as CSV to *filename*, index included.

    Args:
        df: DataFrame to persist.
        filename: Path or writable buffer accepted by ``DataFrame.to_csv``.
    """
    df.to_csv(path_or_buf=filename)
def calculate_accuracy(df):
    """Return the mean ``label`` for every distinct ``text`` value.

    With 0/1 labels the mean is the fraction of rows labelled 1, which the
    app plots as that text's accuracy.

    Args:
        df: DataFrame with ``text`` and ``label`` columns.

    Returns:
        dict mapping each ``text`` value (in groupby order) to its mean label.
    """
    return {
        text_value: group_labels.mean()
        for text_value, group_labels in df.groupby('text')['label']
    }
# Give each per-session key a default the first time the script runs in a
# session; values already stored by earlier reruns are left untouched.
for _state_key, _state_default in (
    ('data', None),              # current table, shown in the data editor
    ('new_rows', []),            # rows appended through add_row
    ('file_path', None),         # CSV path the table is saved to
    ('add_row_clicked', False),  # whether the "Add Row" inputs are shown
    ('rerun_count', 0),          # bumped by add_row before st.rerun()
    ('finished', False),         # set when "Finish" is pressed
):
    if _state_key not in st.session_state:
        st.session_state[_state_key] = _state_default
def add_row(new_text, new_label):
    """Append a row to the current table, persist it, and rerun the app.

    The updated table is written to ``st.session_state.file_path`` and then
    re-read from that file, so ``st.session_state.data`` mirrors the CSV.

    Args:
        new_text: Model name for the ``text`` column; surrounding whitespace
            is stripped. A blank name is rejected with a warning.
        new_label: Value for the ``label`` column (the UI offers 0 or 1).

    Side effects:
        On success: appends to ``st.session_state.new_rows``, rewrites the
        CSV, resets ``add_row_clicked``, increments ``rerun_count`` and calls
        ``st.rerun()``, which halts the current script run.
    """
    new_text = (new_text or '').strip()
    if not new_text:
        # An empty name is written as an empty CSV field, which read_csv
        # loads back as NaN; groupby('text') then drops it, so the row would
        # never appear on the leaderboard.
        st.warning('Please enter a model name before adding a row.')
        return

    data = st.session_state.data
    # Check the index length rather than DataFrame.empty: .empty is also
    # True for a frame that has rows but no columns, which would reuse id 0.
    new_id = data.index.max() + 1 if len(data.index) else 0
    new_row = {'id': new_id, 'text': new_text, 'label': new_label, 'checked': False}
    st.session_state.new_rows.append(new_row)

    updated_data = pd.concat([data, pd.DataFrame([new_row]).set_index('id')])
    file_path = st.session_state.file_path
    save_data(updated_data, file_path)
    # Reload so the in-memory table matches exactly what was written.
    st.session_state.data = load_data(file_path)

    st.session_state.add_row_clicked = False  # close the "Add Row" inputs
    # A non-zero count stops the main script from re-reading the CSV on later runs.
    st.session_state.rerun_count += 1
    st.rerun()
# Streamlit app. Streamlit re-executes this whole script on every
# interaction, so everything below runs once per rerun.
st.title("Interactive DataFrame Editor")

# A fixed CSV path stands in for a file uploader; the `is not None` branch
# is kept so an uploader can be restored later.
uploaded_file = data_path
if uploaded_file is not None:
    st.session_state.file_path = uploaded_file
    # Re-read the CSV while rerun_count is 0, i.e. until add_row has run
    # (add_row bumps the counter and refreshes `data` itself).
    if st.session_state.rerun_count == 0:
        st.session_state.data = load_data(uploaded_file)

    st.subheader("DataFrame")
    if st.session_state.data is not None:
        edited_data = st.data_editor(st.session_state.data)
        # Persist only when the editor actually changed something instead of
        # rewriting the CSV on every rerun.
        if edited_data is not None and not edited_data.equals(st.session_state.data):
            st.session_state.data = edited_data
            save_data(st.session_state.data, st.session_state.file_path)

        if st.button("Add Row"):
            st.session_state.add_row_clicked = True
        if st.session_state.add_row_clicked:
            # Inputs for adding a new row
            new_text = st.text_input("Enter model name for new row:")
            new_label = st.selectbox("Select label for new row:", options=[0, 1])
            if st.button("Confirm Add Row"):
                add_row(new_text, new_label)

        # Leaderboard: mean label per text value, one point per text.
        accuracy_dict = calculate_accuracy(st.session_state.data)
        texts = list(accuracy_dict)
        accuracies = list(accuracy_dict.values())
        fig, ax = plt.subplots(figsize=(10, 4))
        ax.scatter(texts, accuracies)
        ax.set_xlabel('Text')
        ax.set_ylabel('Accuracy')
        ax.set_title('Accuracy of Labels for Each Text Attribute')
        # Rotate x-axis labels for readability; set on `ax` rather than via
        # pyplot's implicit "current axes".
        ax.tick_params(axis='x', labelrotation=90)
        st.subheader("Leaderboard")
        st.pyplot(fig)
        # pyplot keeps every figure alive until it is closed; without this,
        # each rerun would leak one figure.
        plt.close(fig)

        # Button to finish: push the current table back to the Space repo.
        if st.button('Finish'):
            if not HF_TOKEN:
                # Previously HF_TOKEN.replace(...) raised AttributeError here.
                st.error('HF_TOKEN is not set, so the data cannot be saved to the Space.')
            else:
                st.success('Saving.... Space will restart soon....')
                st.session_state.finished = True
                # Strip quotes that may surround the token value in the env var.
                fs = HfFileSystem(token=HF_TOKEN.replace("\"", ""))
                # Resolves to 'spaces/nsourlos/draco_streamlit/data.csv'.
                with fs.open(f'spaces/{REPO_ID}/{data_path}', 'w') as f:
                    f.write(st.session_state.data.to_csv())
else:
    st.write("Please upload a CSV file to get started.")