|
|
import csv
import json
import os
import subprocess

import gradio as gr
from datasets import load_dataset
from transformers import (
    AutoModelForSequenceClassification,
    pipeline,
)
|
|
|
|
|
|
|
|
|
|
|
# Hugging Face Hub id of the fine-tuned emotion-classification checkpoint.
model_ckpt = "echung682/finetuned-emotion-ai-model"

# NOTE(review): `model` is loaded here but never referenced again in this
# file — `pipe` below is built from the checkpoint name, not this object.
# Consider removing this load to cut startup time/memory (confirm nothing
# else imports it first).
model = AutoModelForSequenceClassification.from_pretrained(model_ckpt)

# Inference pipeline; the task is inferred from the checkpoint's config.
pipe = pipeline(model=model_ckpt)

# Tweet dataset used as the pool of prompts shown to annotators.
emotion_dataset = load_dataset("echung682/emotion-analysis-tweets")
|
|
|
|
|
|
|
|
|
|
|
def save_to_repo():
    """Commit and push feedback_data/flagged.csv to the git repo.

    Configures the git identity from the GIT_EMAIL / GIT_USER environment
    variables, pulls the latest main, then commits and pushes the flagged
    CSV. Failures are logged and swallowed (best-effort persistence) so
    the UI keeps working even when a git step fails — e.g. when there is
    nothing to commit.
    """
    try:
        git_steps = [
            ["git", "config", "--global", "user.email", os.environ["GIT_EMAIL"]],
            ["git", "config", "--global", "user.name", os.environ["GIT_USER"]],
            ["git", "pull", "origin", "main"],
            ["git", "add", "feedback_data/flagged.csv"],
            ["git", "commit", "-m", "Update flagged data"],
            ["git", "push", "origin", "main"],
        ]
        for step in git_steps:
            subprocess.run(step, check=True)
    except KeyError as e:
        # Previously an unhandled crash: GIT_EMAIL / GIT_USER not set.
        print(f"Missing git credential env var: {e}")
    except subprocess.CalledProcessError as e:
        print(f"Git operation failed: {e}")
|
|
|
|
|
|
|
|
'''
Tracks which prompt was most recently served for human feedback, so that
progress is preserved and successive users are not all shown the same prompt.
'''
|
|
def load_state():
    """Load shared annotation progress from state.json.

    Returns:
        tuple: (count, processed_indices) — the index of the last prompt
        served and the list of dataset indices already annotated. Returns
        (0, []) when the file is missing or unreadable.
    """
    try:
        with open("state.json", "r") as f:
            state = json.load(f)
        return state.get("count", 0), state.get("processed_indices", [])
    except (FileNotFoundError, json.JSONDecodeError):
        # No state yet, or a corrupt/partially-written file: start fresh
        # instead of crashing the app (JSONDecodeError was unhandled before).
        return 0, []
|
|
|
|
|
|
|
|
def save_state(count, processed_indices):
    """Persist annotation progress to state.json in the working directory."""
    state = {"count": count, "processed_indices": processed_indices}
    with open("state.json", "w") as handle:
        handle.write(json.dumps(state))
|
|
|
|
|
def get_next_prompt():
    """Pick the next unannotated dataset index, record it, and return it.

    Reads the shared counter from state.json, skips indices already
    annotated (wrapping around the dataset), marks the chosen index as
    processed, then persists the state locally and pushes it to the repo.
    """
    current, seen = load_state()
    total = len(emotion_dataset["train"])

    # Every index has been annotated — begin a fresh pass over the dataset.
    if len(seen) >= total:
        seen = []

    # Advance (with wrap-around) until we land on an unseen index.
    while current in seen:
        current = (current + 1) % total

    seen.append(current)
    save_state(current, seen)
    save_state_to_repo()

    return current
|
|
|
|
|
def save_state_to_repo():
    """Commit and push state.json so progress is shared across users.

    Best-effort: a failing git command (e.g. nothing to commit) is logged
    and otherwise ignored.
    """
    git_steps = (
        ["git", "pull", "origin", "main"],
        ["git", "add", "state.json"],
        ["git", "commit", "-m", "Update state"],
        ["git", "push", "origin", "main"],
    )
    try:
        for step in git_steps:
            subprocess.run(step, check=True)
    except subprocess.CalledProcessError as e:
        print(f"Git operation failed: {e}")
|
|
|
|
|
|
|
|
'''
Records the prompt, its two candidate options, and the option the user chose,
then advances the shared index so successive users are not shown the same
prompts. Appends the new row to the local feedback CSV, and pushes both the
feedback data and the updated index to the repo so state persists across users.
'''
|
|
def updateDataset(prompt, option1, option2, flagged_option):
    """Record one human preference and advance to the next prompt.

    Args:
        prompt: The tweet text that was shown to the user.
        option1: The model's top-scoring emotion label.
        option2: The model's second-highest emotion label.
        flagged_option: The label the user selected in the radio group.

    Returns:
        (prompt, chosen, rejected, confirmation message, next prompt index).
    """
    # Map the user's selection onto chosen/rejected; anything unexpected
    # (e.g. no selection) records an empty pair rather than crashing.
    if flagged_option == option1:
        chosen, rejected = option1, option2
    elif flagged_option == option2:
        chosen, rejected = option2, option1
    else:
        chosen, rejected = "", ""

    index = get_next_prompt()

    # Use the csv module so prompts containing commas, quotes, or newlines
    # cannot corrupt the file (the previous raw f.write was injectable).
    with open("feedback_data/flagged.csv", "a", newline="") as f:
        csv.writer(f).writerow([prompt, chosen, rejected])

    save_to_repo()
    save_state_to_repo()

    return prompt, chosen, rejected, "Submitted! Please answer another...", index
|
|
|
|
|
|
|
|
'''
Looks up the prompt at the shared global index, runs the emotion pipeline on
it, and returns the prompt together with its two highest-scoring emotions.
'''
|
|
def emotion_analysis_data_collection():
    """Fetch the next prompt and the model's top two emotion labels for it.

    Returns:
        (text, top_label, second_label): the tweet to show the user and the
        two highest-scoring emotion labels predicted by the pipeline.
    """
    prompt_index = get_next_prompt()
    text = emotion_dataset["train"]["text"][prompt_index]

    # top_k=None asks the pipeline for a score on every label.
    result = pipe(text, top_k=None)

    # BUG FIX: the original loop reused the name `index` for enumerate(),
    # clobbering the prompt index, so the text returned at the end came
    # from the wrong dataset row. Build the label->score map without
    # shadowing it.
    emotion_scores = {item["label"]: item["score"] for item in result}

    # Rank by score explicitly instead of relying on the pipeline's
    # output ordering.
    ranked = sorted(emotion_scores, key=emotion_scores.get, reverse=True)
    emotion_highestScore = ranked[0]
    emotion_secondHighestScore = ranked[1]

    return text, emotion_highestScore, emotion_secondHighestScore
|
|
|
|
|
|
|
|
|
|
|
|
|
|
'''
Builds the Gradio interface: shows the prompt, the two candidate emotions,
and a Radio component that records which emotion the user chose.
'''
|
|
# Build the survey UI. Components render in declaration order inside the
# Blocks context; launch() starts the web server.
with gr.Blocks() as survey:
    gr.Markdown(
        """
        # Please choose the emotion that best describes the prompt
        """
    )

    # Pick and score the next prompt once, at app-construction time.
    # NOTE(review): this runs only when the script starts, so every visitor
    # to one running instance sees the same prompt until restart — confirm
    # that is intended.
    tweet, emotion_highestScore, emotion_secondHighestScore = emotion_analysis_data_collection()

    # Read-only display of the tweet being annotated.
    sentence = gr.Textbox(tweet, label="Prompt:", interactive=False)

    # The two candidate emotions, shown side by side.
    with gr.Row():
        emotion1 = gr.Textbox(emotion_highestScore, label="Emotion Choice 1:", interactive=False)
        emotion2 = gr.Textbox(emotion_secondHighestScore, label="Emotion Choice 2", interactive=False)

    # The user's selection; passed to updateDataset as `flagged_option`.
    options = gr.Radio([emotion_highestScore, emotion_secondHighestScore], label="Choose one:")

    submit_btn = gr.Button("Submit Choice")

    # On click, record the choice and echo back what was saved.
    submit_btn.click(fn=updateDataset,
                     inputs=[sentence, emotion1, emotion2, options],
                     outputs=[gr.Textbox(label="Prompt"), gr.Textbox(label="Chosen Response"), gr.Textbox(label="Rejected Response"), gr.Textbox(label="Confirmation Message"), gr.Textbox(label="Prompt Number")],
                     )

survey.launch()