# HuggingFace Space: echung682/app.py (commit b309b24, verified)
import csv
import json
import os
import subprocess

import gradio as gr
from datasets import load_dataset
from transformers import (
    AutoModelForSequenceClassification,
    pipeline,
)
# importing the model
# Hub checkpoint of the fine-tuned emotion classifier.
model_ckpt = "echung682/finetuned-emotion-ai-model"
# NOTE(review): `model` is loaded here but never used again in this file —
# the pipeline below re-loads the same checkpoint by name, so the weights are
# effectively fetched twice; consider dropping one of the two loads.
model = AutoModelForSequenceClassification.from_pretrained(model_ckpt)
# Task is inferred from the checkpoint's config (sequence classification).
pipe = pipeline(model=model_ckpt)
# importing the dataset (a whole bunch of text)
emotion_dataset = load_dataset("echung682/emotion-analysis-tweets")
# in order to keep the data persistent on the HuggingFace repo;
# GIT_EMAIL / GIT_USER are Space secrets so they aren't visible in the code
def save_to_repo():
    """Commit and push feedback_data/flagged.csv back to the Space repo.

    Configures the git identity from the GIT_EMAIL / GIT_USER secrets,
    pulls first to avoid non-fast-forward pushes, then stages, commits and
    pushes the flagged-feedback CSV. A failing git command is reported on
    stdout instead of raising.
    """
    git_steps = (
        ["git", "config", "--global", "user.email", os.environ["GIT_EMAIL"]],
        ["git", "config", "--global", "user.name", os.environ["GIT_USER"]],
        ["git", "pull", "origin", "main"],
        ["git", "add", "feedback_data/flagged.csv"],
        ["git", "commit", "-m", "Update flagged data"],
        ["git", "push", "origin", "main"],
    )
    try:
        for cmd in git_steps:
            subprocess.run(cmd, check=True)
    except subprocess.CalledProcessError as e:
        print(f"Git operation failed: {e}")
def load_state():
    """Return the persisted (count, processed_indices) prompt cursor.

    Keeps track of which prompt was last given human feedback so the survey
    resumes correctly across restarts and users. Falls back to ``(0, [])``
    when state.json is missing, empty, or corrupt (e.g. a partially-written
    file from an interrupted push), instead of crashing the app at startup.
    """
    try:
        with open("state.json", "r") as f:
            state = json.load(f)
    except (FileNotFoundError, json.JSONDecodeError):
        # First run, or an unreadable/corrupt state file: start fresh.
        return 0, []
    return state.get("count", 0), state.get("processed_indices", [])
# Persist the prompt cursor to disk so it survives restarts.
def save_state(count, processed_indices):
    """Write the current prompt count and served-index list to state.json."""
    state = {
        "count": count,
        # indices of prompts that have already been shown to users
        "processed_indices": processed_indices,
    }
    with open("state.json", "w") as f:
        json.dump(state, f)
def get_next_prompt():
    """Advance to the next unserved prompt index and persist the new state.

    Walks forward (wrapping around the dataset) until it finds an index that
    has not yet been served; when every prompt has been served the served
    list is cleared and the cycle starts over. The updated cursor is saved
    locally and pushed to the repo so all users share it.
    """
    cursor, served = load_state()
    dataset_len = len(emotion_dataset["train"])
    if len(served) >= dataset_len:
        # Every prompt has received feedback at least once; begin a new pass.
        served = []
    # Skip indices that were already served, wrapping at the dataset end.
    while cursor in served:
        cursor = (cursor + 1) % dataset_len
    served.append(cursor)
    save_state(cursor, served)
    save_state_to_repo()
    return cursor
def save_state_to_repo():
    """Commit and push state.json so the prompt cursor is shared by all users.

    Mirrors save_to_repo(): pull first, then stage/commit/push; a failing
    git command is reported on stdout instead of raising.
    """
    git_steps = (
        ["git", "pull", "origin", "main"],
        ["git", "add", "state.json"],
        ["git", "commit", "-m", "Update state"],
        ["git", "push", "origin", "main"],
    )
    try:
        for cmd in git_steps:
            subprocess.run(cmd, check=True)
    except subprocess.CalledProcessError as e:
        print(f"Git operation failed: {e}")
def updateDataset(prompt, option1, option2, flagged_option):
    """Record one piece of human feedback and advance the shared prompt cursor.

    Called when a user clicks the submit button. Maps the selected radio
    value to (chosen, rejected), appends the row to
    feedback_data/flagged.csv, advances the prompt index (so multiple users
    don't all review the same prompt), and pushes both the feedback file and
    the state file to the HuggingFace repo.

    Returns (prompt, chosen, rejected, confirmation message, next index) —
    the outputs wired to the Gradio interface.
    """
    if flagged_option == option1:
        chosen, rejected = option1, option2
    elif flagged_option == option2:
        chosen, rejected = option2, option1
    else:
        # Shouldn't happen with radio buttons, but keep the row well-formed.
        chosen, rejected = "", ""
    index = get_next_prompt()
    # The directory may not exist on a fresh clone of the Space.
    os.makedirs("feedback_data", exist_ok=True)
    # csv.writer quotes fields as needed, so prompts containing commas or
    # newlines no longer corrupt the file (the raw f.write version did).
    with open("feedback_data/flagged.csv", "a", newline="") as f:
        csv.writer(f).writerow([prompt, chosen, rejected])
    # Push the updated feedback and state files to the HuggingFace repo.
    save_to_repo()
    save_state_to_repo()
    return prompt, chosen, rejected, "Submitted! Please answer another...", index
def emotion_analysis_data_collection():
    """Fetch the next prompt and its two highest-scoring predicted emotions.

    Looks up the prompt text at the next shared index, runs the emotion
    classifier over it, and returns
    (prompt text, top emotion label, second emotion label).
    """
    index = get_next_prompt()
    # BUG FIX: the original reused `index` as an enumerate loop variable,
    # so the text returned at the end was looked up with the last loop
    # index (always len(result) - 1) instead of the prompt index. Capture
    # the text up front and never reuse the name.
    text = emotion_dataset["train"]["text"][index]
    # top_k=None makes the pipeline return a score for every label.
    result = pipe(text, top_k=None)
    # Sort explicitly by score rather than relying on pipeline ordering.
    ranked = sorted(result, key=lambda r: r["score"], reverse=True)
    return text, ranked[0]["label"], ranked[1]["label"]
'''
designing the gradio interface
has the two options and a Radio object that will keep track of the chosen emotion
'''
# Build the survey UI. Components are created in display order inside the
# Blocks context; the prompt and its two candidate emotions are computed
# once at startup.
with gr.Blocks() as survey:
    gr.Markdown(
        """
# Please choose the emotion that best describes the prompt
"""
    )
    # Next prompt plus the model's two highest-scoring emotions for it.
    tweet, emotion_highestScore, emotion_secondHighestScore = emotion_analysis_data_collection()
    sentence = gr.Textbox(tweet, label="Prompt:", interactive=False)
    # NOTE(review): original indentation was lost; assuming the two emotion
    # boxes sit side by side in the Row and the Radio sits below it — confirm.
    with gr.Row():
        emotion1 = gr.Textbox(emotion_highestScore, label="Emotion Choice 1:", interactive=False)
        emotion2 = gr.Textbox(emotion_secondHighestScore, label="Emotion Choice 2", interactive=False)
    options = gr.Radio([emotion_highestScore, emotion_secondHighestScore], label="Choose one:")
    submit_btn = gr.Button("Submit Choice")
    # Wire the button: updateDataset receives the prompt, both options, and
    # the user's radio selection, and its 5-tuple return fills these outputs.
    submit_btn.click(fn=updateDataset,
        inputs=[sentence, emotion1, emotion2, options],
        outputs=[gr.Textbox(label="Prompt"), gr.Textbox(label="Chosen Response"), gr.Textbox(label="Rejected Response"), gr.Textbox(label="Confirmation Message"), gr.Textbox(label="Prompt Number")],
    )
survey.launch()