Spaces:
Sleeping
Sleeping
spdin
committed on
Commit
·
333cd19
1
Parent(s):
0da7162
initial commit
Browse files- app.py +33 -0
- model.py +47 -0
- prediction.py +48 -0
- training.py +74 -0
- utils.py +12 -0
- validation.py +65 -0
app.py
ADDED
|
@@ -0,0 +1,33 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import uuid
|
| 2 |
+
import streamlit as st
|
| 3 |
+
|
| 4 |
+
import training
|
| 5 |
+
import validation
|
| 6 |
+
import prediction
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
# Session initialization
|
| 10 |
+
# Create the session identifier only if this Streamlit session does not
# already have one stored under "key".
if "key" not in st.session_state:
    # The last 12 hex digits of a UUID4 are its final dash-separated group.
    st.session_state["key"] = uuid.uuid4().hex[-12:]
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
def training_page():
    """Render the training page by delegating to ``training.main``."""
    training.main()
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
def validation_page():
    """Render the validation page by delegating to ``validation.main``."""
    validation.main()
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
def prediction_page():
    """Render the prediction page by delegating to ``prediction.main``."""
    prediction.main()
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
# Sidebar label -> function that renders the corresponding page.
page_names_to_funcs = {
    "Training": training_page,
    "Validation": validation_page,
    "Prediction": prediction_page,
}

# Offer the page names in the sidebar, then call the renderer that was picked.
page_choices = list(page_names_to_funcs)
selected_page = st.sidebar.selectbox("Select a page", page_choices)
render_selected_page = page_names_to_funcs[selected_page]
render_selected_page()
|
model.py
ADDED
|
@@ -0,0 +1,47 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from setfit import SetFitModel, SetFitTrainer
|
| 2 |
+
from sentence_transformers.losses import CosineSimilarityLoss
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
# Function to create a pipeline for text classification using the trained model
|
| 6 |
+
def create_classifier(model_path):
    """Load a saved SetFit model from ``model_path``.

    ``local_files_only=True`` is passed to ``SetFitModel.from_pretrained``, so
    the model is expected to already exist on disk at ``model_path``.

    Returns:
        The loaded ``SetFitModel``.
    """
    return SetFitModel.from_pretrained(model_path, local_files_only=True)
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
def run_setfit_training(
    session_id, model_id, model_name, train_dataset, batch_size, num_iterations
):
    """Fine-tune a SetFit model on ``train_dataset`` and save it to disk.

    Args:
        session_id: Identifier used as the sub-directory under ``./models``.
        model_id: Pretrained model name passed to ``SetFitModel.from_pretrained``.
        model_name: Suffix appended to ``model_id`` in the save directory name.
        train_dataset: Dataset with ``text`` and ``label`` columns (see the
            ``column_mapping`` below).
        batch_size: Batch size handed to ``SetFitTrainer``.
        num_iterations: Number of text pairs to generate for contrastive
            learning.

    Returns:
        The path of the directory the trained model was saved to,
        ``./models/{session_id}/{model_id}_{model_name}``.
    """
    model = SetFitModel.from_pretrained(model_id)

    trainer = SetFitTrainer(
        model=model,
        train_dataset=train_dataset,
        # NOTE(review): the evaluation set is the training set itself, and
        # nothing in this function calls trainer.evaluate().
        eval_dataset=train_dataset,
        loss_class=CosineSimilarityLoss,
        metric="accuracy",
        batch_size=batch_size,
        num_iterations=num_iterations,  # text pairs generated for contrastive learning
        num_epochs=1,  # epochs used for contrastive learning
        column_mapping={"text": "text", "label": "label"},
    )

    trainer.train()

    print(f"model used: {model_id}")
    print(f"train dataset: {len(train_dataset)} samples")

    # Build the path once so the saved location and the returned value cannot
    # drift apart.
    save_model_path = f"./models/{session_id}/{model_id}_{model_name}"

    # Use the public save API instead of the private ``_save_pretrained`` hook.
    trainer.model.save_pretrained(save_directory=save_model_path)

    return save_model_path
|
prediction.py
ADDED
|
@@ -0,0 +1,48 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import pickle
|
| 3 |
+
|
| 4 |
+
import streamlit as st
|
| 5 |
+
import model
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
def main():
    """Render the prediction page: pick a trained model and classify one text.

    Models are looked up under ``models/{session_id}``. The page stops early
    with a message when that directory is missing or empty, and refuses to
    predict on blank input.
    """
    st.title("Model Prediction")

    session_id = st.session_state.key
    st.write(f"Session ID: {session_id}")

    models_dir = f"models/{session_id}"
    if not os.path.isdir(models_dir):
        st.write("Model is not available")
        st.stop()

    model_options = os.listdir(models_dir)
    if not model_options:
        # With no options the selectbox yields None, and the lookup in
        # ``models`` below would raise KeyError.
        st.write("Model is not available")
        st.stop()

    # Model directory name -> absolute path of that directory.
    models = {
        name: os.path.abspath(os.path.join(models_dir, name))
        for name in model_options
    }

    model_name = st.selectbox("Select a model", options=model_options)

    # Text input
    text = st.text_area("Enter some text here", height=200)

    # Prediction button
    if st.button("Predict"):
        if not text.strip():
            st.error("Please enter some text to classify.")
            st.stop()

        model_path = models[model_name]

        # NOTE(security): pickle.load can execute arbitrary code. This is only
        # acceptable while label.pkl is written solely by this application.
        with open(f"{model_path}/label.pkl", "rb") as f:
            label_map = pickle.load(f)

        classifier = model.create_classifier(model_path)

        # The classifier returns one prediction per input text; .item()
        # converts the first one to a plain Python value used as the class id.
        prediction = classifier([text])
        prediction_class = prediction[0].item()

        # Probability assigned to the predicted class; this indexing assumes
        # class ids are positions in the probability vector.
        confidence_score = classifier.predict_proba([text])[0][prediction_class].item()

        st.write(
            "The predicted label is:",
            label_map[prediction_class],
            f"{round(confidence_score*100,2)}%",
        )
|
training.py
ADDED
|
@@ -0,0 +1,74 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import pickle
|
| 2 |
+
import pandas as pd
|
| 3 |
+
import streamlit as st
|
| 4 |
+
|
| 5 |
+
from datasets import Dataset
|
| 6 |
+
|
| 7 |
+
import model
|
| 8 |
+
from utils import check_columns, count_labels
|
| 9 |
+
|
| 10 |
+
# Main function to run the Streamlit app
|
| 11 |
+
def main():
    """Render the training page: upload a labelled CSV and fine-tune a model.

    The uploaded CSV must pass ``check_columns``. Labels are mapped to integer
    ids before training, and the inverse mapping (id -> original label) is
    pickled as ``label.pkl`` inside the saved model directory.
    """
    # Set app title
    st.title("Few Shot Learning Demo using SetFit")

    # Display the session ID
    session_id = st.session_state.key
    st.write(f"Session ID: {session_id}")

    # Create file uploader
    uploaded_file = st.file_uploader("Choose a CSV file to upload", type="csv")
    if uploaded_file is None:
        return

    # Read CSV file into pandas DataFrame
    df = pd.read_csv(uploaded_file)

    if not check_columns(df):
        st.error("File must have 'text' and 'label' columns.")
        return

    # Display DataFrame as a table
    st.write(df)

    # Calculate the number of instances of each label class
    label_counts = count_labels(df)
    st.write(f"Number of instances of each label class: {label_counts}")

    # dict.fromkeys de-duplicates while keeping first-appearance order, so the
    # label -> id assignment is reproducible for the same file. Iterating a
    # set gives an order that can change between runs for string labels.
    unique_labels = dict.fromkeys(df["label"].tolist())
    label_map = {label: idx for idx, label in enumerate(unique_labels)}

    df["label"] = df["label"].map(label_map)

    dataset = Dataset.from_pandas(df)

    model_name = st.text_input("Input the model name")

    pretrained_model_options = ["all-MiniLM-L6-v2", "paraphrase-MiniLM-L3-v2"]

    pretrained_model = st.selectbox(
        "Select a pretrained model", options=pretrained_model_options
    )

    if not st.button("Train"):
        return

    # The name becomes part of a filesystem path. A separator would nest or
    # relocate the model directory, so the other pages could not list it.
    if "/" in model_name or "\\" in model_name:
        st.error("Model name must not contain path separators.")
        return

    # Train the model
    with st.spinner("Training model..."):
        model_path = model.run_setfit_training(
            session_id,
            pretrained_model,
            model_name,
            dataset,
            batch_size=1,
            num_iterations=10,
        )

    st.write(f"Model checkpoint saved {model_path.split('/')[-1]}")

    # Store the inverse mapping so predictions can be turned back into the
    # original labels.
    index_to_label = {idx: label for label, idx in label_map.items()}
    with open(f"{model_path}/label.pkl", "wb") as f:
        pickle.dump(index_to_label, f)

    st.write("Training Finished")
    st.write("Go to Validation Page")
|
utils.py
ADDED
|
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Function to check if the uploaded file has the expected columns
|
| 2 |
+
def check_columns(df):
    """Return True when the DataFrame's columns are exactly ``text`` and ``label``.

    Column order does not matter; any missing or additional column makes the
    check fail.
    """
    return set(df.columns) == {"text", "label"}
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
# Function to calculate the number of instances of each label class
|
| 10 |
+
def count_labels(df):
    """Return a dict mapping each value in ``df["label"]`` to its row count."""
    return df["label"].value_counts().to_dict()
|
validation.py
ADDED
|
@@ -0,0 +1,65 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import pickle
|
| 3 |
+
import pandas as pd
|
| 4 |
+
import streamlit as st
|
| 5 |
+
|
| 6 |
+
import model
|
| 7 |
+
|
| 8 |
+
from utils import check_columns
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
# Function to validate the trained model with a new uploaded CSV file
|
| 12 |
+
def main():
    """Render the validation page: run a trained model over a labelled CSV.

    Models are looked up under ``models/{session_id}``. The uploaded CSV must
    pass ``check_columns``. Predictions are added as a ``predicted_label``
    column, and the share of rows where it equals ``label`` is reported.
    """
    st.title("Model Validation")

    # Display the session ID
    session_id = st.session_state.key
    st.write(f"Session ID: {session_id}")

    models_dir = f"models/{session_id}"
    if not os.path.isdir(models_dir):
        st.write("Model is not available")
        st.stop()

    model_options = os.listdir(models_dir)
    if not model_options:
        # With no options the selectbox yields None, and the lookup in
        # ``models`` below would raise KeyError.
        st.write("Model is not available")
        st.stop()

    # Model directory name -> absolute path of that directory.
    models = {
        name: os.path.abspath(os.path.join(models_dir, name))
        for name in model_options
    }

    model_name = st.selectbox("Select a model", options=model_options)

    # Create file uploader for validation CSV file
    validation_file = st.file_uploader(
        "Choose a CSV file to validate the model", type="csv"
    )
    if validation_file is None:
        return

    # Read CSV file into pandas DataFrame
    validation_df = pd.read_csv(validation_file)

    if not check_columns(validation_df):
        st.error("Validation file must have 'text' and 'label' columns.")
        return

    # Display DataFrame as a table
    st.write(validation_df)

    model_path = models[model_name]
    classifier = model.create_classifier(model_path)

    # NOTE(security): pickle.load can execute arbitrary code. This is only
    # acceptable while label.pkl is written solely by this application.
    with open(f"{model_path}/label.pkl", "rb") as f:
        label_map = pickle.load(f)

    results = classifier(validation_df["text"].tolist())

    # Translate each predicted class id back to its original label.
    validation_df["predicted_label"] = [
        label_map[result.item()] for result in results
    ]

    # Display validation DataFrame with predicted labels
    st.write("Validation results:")
    st.write(validation_df)

    # Fraction of rows whose prediction equals the provided label.
    accuracy = (validation_df["predicted_label"] == validation_df["label"]).mean()
    st.write(f"Accuracy: {accuracy:.2%}")
|