# Hs-Tool / app.py — Streamlit audio annotation tool (uploaded by kcrl, commit b33d909)
import streamlit as st
import numpy as np
from datasets import load_dataset, Audio, Dataset
import os
from huggingface_hub import HfApi, login, create_repo, upload_file
import tempfile
import time
import json
from datetime import datetime
# Set page config
st.set_page_config(
    page_title="Audio Annotation Tool",
    layout="wide",
)

# Default dataset this tool annotates (field is shown disabled in the sidebar).
DEFAULT_DATASET = "kcrl/Hate-Speech"

# HF token from the Spaces secret if available, otherwise empty.
# os.environ.get never raises for a missing key, and `os` is already imported
# at the top of the file, so no try/except or re-import is needed here.
DEFAULT_HF_TOKEN = os.environ.get("HF_TOKEN", "")

# List of authorized annotators
AUTHORIZED_ANNOTATORS = [
    "Ayon",
    # Add more authorized names here
]
# Save persistent data to a file in the current directory
def get_session_file_path(annotator_name):
    """Return the per-annotator JSON state file path, creating the data dir.

    The annotator name is lower-cased with spaces turned into underscores;
    an empty or None name maps to the "default" state file.
    """
    if annotator_name:
        safe_name = annotator_name.lower().replace(" ", "_")
    else:
        safe_name = "default"
    data_dir = "annotation_data"
    # Idempotent: safe to call on every save/load.
    os.makedirs(data_dir, exist_ok=True)
    return os.path.join(data_dir, f"annotation_state_{safe_name}.json")
# Function to save session state
def save_session_state(annotator_name):
    """Persist this annotator's progress to their local JSON state file.

    Best-effort: failures only produce a Streamlit warning, never an exception.
    """
    if not annotator_name:
        return
    try:
        # Snapshot just the keys needed to resume a session.
        snapshot = {
            key: st.session_state[key]
            for key in (
                "current_index",
                "annotations",
                "dataset_name",
                "dataset_split",
                "annotator_name",
            )
        }
        snapshot["last_updated"] = datetime.now().isoformat()
        with open(get_session_file_path(annotator_name), "w") as fh:
            json.dump(snapshot, fh, indent=2)
    except Exception as exc:
        st.warning(f"Could not save session state: {str(exc)}")
# Function to load session state
def load_session_state(annotator_name):
    """Restore a previously saved session for *annotator_name*.

    Returns True when a saved state file was found and loaded into
    st.session_state, False otherwise (missing name, missing file, or error).
    """
    if not annotator_name:
        return False
    try:
        state_file = get_session_file_path(annotator_name)
        if not os.path.exists(state_file):
            return False
        with open(state_file) as fh:
            saved = json.load(fh)
        st.session_state.current_index = saved.get("current_index", 0)
        st.session_state.annotations = saved.get("annotations", {})
        st.session_state.dataset_name = saved.get("dataset_name", DEFAULT_DATASET)
        st.session_state.dataset_split = saved.get("dataset_split", "train")
        return True
    except Exception as exc:
        st.warning(f"Could not load session state: {str(exc)}")
        return False
# Initialize session state variables if they don't exist
# Each default is a zero-argument factory so that side-effecting defaults
# (e.g. tempfile.mkdtemp) run only when the key is actually missing.
_SESSION_DEFAULTS = {
    "current_index": lambda: 0,
    "annotations": dict,
    "dataset_initialized": lambda: False,
    "temp_dir": tempfile.mkdtemp,
    "audio_file": lambda: None,
    "dataset_info": lambda: None,
    "current_sample": lambda: None,
    "dataset_name": lambda: DEFAULT_DATASET,
    "dataset_split": lambda: "train",
    "class_labels": lambda: ["hate", "non-hate", "discard"],
    "annotator_name": lambda: "",
    "hf_token": lambda: DEFAULT_HF_TOKEN,
    "total_samples": lambda: 0,
    "authorized": lambda: False,
    "state_loaded": lambda: False,
}
for _key, _factory in _SESSION_DEFAULTS.items():
    if _key not in st.session_state:
        st.session_state[_key] = _factory()
# Function to load a single sample
def load_single_sample(dataset_name, split, index):
    """Fetch the example at *index* from the given dataset split.

    Returns the sample dict, or None (with a Streamlit error shown) when the
    index is out of range or loading fails. Note that the whole split is
    loaded on each call; `datasets` caching makes repeat calls cheap.
    """
    try:
        ds = load_dataset(dataset_name, split=split)
        if index >= len(ds):
            st.error(f"Index {index} is out of range (dataset has {len(ds)} samples)")
            return None
        sample = ds[index]
        # Decode raw audio payloads that are not already dicts.
        if "audio" in sample and not isinstance(sample["audio"], dict):
            sample["audio"] = Audio().decode_example(sample["audio"])
        return sample
    except Exception as exc:
        st.error(f"Error loading sample {index}: {str(exc)}")
        return None
# Function to get dataset length
def get_dataset_length(dataset_name, split):
    """Return the number of examples in *split*, or 0 on any failure."""
    try:
        return len(load_dataset(dataset_name, split=split))
    except Exception as exc:
        st.error(f"Error getting dataset length: {str(exc)}")
        return 0
# Title and description
st.title("Audio Annotation Tool")
# Intro copy shown above the main annotation UI.
st.markdown("""
This tool allows you to annotate audio files from a Hugging Face dataset.
Your progress is automatically saved and will be restored when you return.
""")
# Sidebar for configuration: annotator identity, session restore, and
# dataset initialization. Widget order matters in Streamlit, so the flow
# below is deliberately sequential.
with st.sidebar:
    st.header("Configuration")
    # Dataset configuration (pre-filled with default; the name field is
    # disabled, so dataset_name is effectively pinned to DEFAULT_DATASET)
    st.session_state.dataset_name = st.text_input("Hugging Face Dataset Name", value=DEFAULT_DATASET, disabled=True)
    st.session_state.dataset_split = st.text_input("Dataset Split", value=st.session_state.dataset_split)
    # Class labels (hard-coded)
    st.text("Labels: hate, non-hate, discard")
    # Annotator information
    annotator_input = st.text_input("Your Name (Annotator)", value=st.session_state.annotator_name)
    # If annotator name changes, try to load their session state
    if annotator_input != st.session_state.annotator_name:
        st.session_state.annotator_name = annotator_input
        if annotator_input and not st.session_state.state_loaded:
            # Try to load session state for this annotator
            state_loaded = load_session_state(annotator_input)
            if state_loaded:
                st.success(f"Loaded previous session for {annotator_input}! You can continue from where you left off.")
                st.session_state.state_loaded = True
                # Auto-initialize if previous session data was found
                st.session_state.dataset_initialized = True
                # Also load the current sample based on the restored index
                try:
                    # Authenticate with Hugging Face (default token)
                    login(token=DEFAULT_HF_TOKEN)
                    # Get dataset size
                    st.session_state.total_samples = get_dataset_length(
                        st.session_state.dataset_name,
                        st.session_state.dataset_split
                    )
                    # Load the current sample
                    st.session_state.current_sample = load_single_sample(
                        st.session_state.dataset_name,
                        st.session_state.dataset_split,
                        st.session_state.current_index
                    )
                    # NOTE(review): st.rerun() restarts the script immediately,
                    # so the success message above is discarded on the rerun —
                    # confirm this is acceptable UX.
                    st.rerun()
                except Exception as e:
                    st.error(f"Error loading sample: {str(e)}")
    # Check if annotator is authorized
    if st.session_state.annotator_name and st.session_state.annotator_name not in AUTHORIZED_ANNOTATORS:
        st.sidebar.error(f"Sorry, {st.session_state.annotator_name} is not authorized to annotate this dataset.")
        st.session_state.authorized = False
    elif st.session_state.annotator_name in AUTHORIZED_ANNOTATORS:
        st.sidebar.success("Annotator authorized.")
        st.session_state.authorized = True
    # Hidden HF token - use default from environment
    st.session_state.hf_token = DEFAULT_HF_TOKEN
    # Initialize dataset button
    initialize_button = st.button("Initialize Dataset")
    # Runs on an explicit click, or automatically when a restored session
    # has not yet finished initializing the dataset connection.
    if initialize_button or (st.session_state.state_loaded and not st.session_state.dataset_initialized):
        if not st.session_state.authorized:
            st.error("You are not authorized to annotate this dataset. Please use an authorized annotator name.")
        else:
            try:
                with st.spinner("Initializing dataset connection..."):
                    # Authenticate with Hugging Face
                    login(token=DEFAULT_HF_TOKEN)
                    # Get the total number of samples without loading the entire dataset
                    st.session_state.total_samples = get_dataset_length(
                        st.session_state.dataset_name,
                        st.session_state.dataset_split
                    )
                    if st.session_state.total_samples > 0:
                        st.session_state.dataset_initialized = True
                        # Load the current sample based on session state
                        st.session_state.current_sample = load_single_sample(
                            st.session_state.dataset_name,
                            st.session_state.dataset_split,
                            st.session_state.current_index
                        )
                        st.success(f"Dataset initialized! Total samples: {st.session_state.total_samples}")
                        st.info(f"Starting from sample {st.session_state.current_index + 1}")
                    else:
                        st.error("Could not determine the size of the dataset or the dataset is empty.")
            except Exception as e:
                st.error(f"Error initializing dataset: {str(e)}")
# Main content: the full annotation UI, shown once a dataset is initialized
# and a sample has been loaded.
if st.session_state.dataset_initialized and st.session_state.current_sample:
    # Display dataset info
    st.subheader("Dataset Information")
    st.write(f"Dataset: {st.session_state.dataset_name}")
    st.write(f"Split: {st.session_state.dataset_split}")
    st.write(f"Total samples: {st.session_state.total_samples}")
    st.write(f"Current sample: {st.session_state.current_index + 1}/{st.session_state.total_samples}")
    st.write(f"Annotations completed: {len(st.session_state.annotations)}")
    # Display audio player
    try:
        # Get the current audio sample
        audio_sample = st.session_state.current_sample
        # Extract audio data (decoded `datasets.Audio` samples are dicts with
        # "array" and "sampling_rate" keys)
        if "audio" in audio_sample and isinstance(audio_sample["audio"], dict):
            audio_data = audio_sample["audio"]["array"]
            sample_rate = audio_sample["audio"]["sampling_rate"]
            # Display metadata if available
            st.subheader("Audio Metadata")
            metadata_cols = [col for col in audio_sample.keys() if col != "audio"]
            if metadata_cols:
                metadata_display = {}
                for col in metadata_cols:
                    metadata_display[col] = audio_sample[col]
                st.json(metadata_display)
            # Create audio player
            st.subheader("Audio Player")
            st.audio(audio_data, format="audio/wav", sample_rate=sample_rate)
            # Annotation interface
            st.subheader("Annotation")
            # Get the existing annotation if available
            # Extract filename more aggressively from metadata — the filename
            # doubles as the annotation key, so several fields are tried in
            # order of preference before falling back to a synthetic name.
            if "file" in audio_sample:
                current_audio_id = audio_sample["file"]
            elif "filename" in audio_sample:
                current_audio_id = audio_sample["filename"]
            elif "path" in audio_sample:
                current_audio_id = audio_sample["path"]
            elif "audio" in audio_sample and "path" in audio_sample["audio"]:
                current_audio_id = os.path.basename(audio_sample["audio"]["path"])
            else:
                # Check for common audio metadata fields that might contain the filename
                audio_fields = [field for field in audio_sample.keys() if field != "audio"]
                filename_found = False
                # Look for any field that might be a filename
                for field in audio_fields:
                    if isinstance(audio_sample[field], str):
                        if any(ext in audio_sample[field].lower() for ext in ['.wav', '.mp3', '.ogg', '.flac']):
                            current_audio_id = os.path.basename(audio_sample[field])
                            filename_found = True
                            break
                # If still not found, try to construct a filename from available fields
                if not filename_found and "id" in audio_sample:
                    current_audio_id = f"{audio_sample['id']}.wav"
                elif not filename_found:
                    # Last resort: key on the sample index (NOTE: index-based
                    # keys are not stable if the dataset is ever reordered).
                    current_audio_id = f"audio_{st.session_state.current_index}.wav"
            # Ensure we have a clean filename, not a path
            current_audio_id = os.path.basename(current_audio_id)
            current_annotation = st.session_state.annotations.get(current_audio_id, None)
            # Show audio filename that's being used as the ID
            st.caption(f"Audio ID: {current_audio_id}")
            # Display annotation options, preselecting any existing annotation
            selected_class = st.radio(
                "Select Class Label",
                options=["hate", "non-hate", "discard"],
                index=["hate", "non-hate", "discard"].index(current_annotation) if current_annotation in ["hate", "non-hate", "discard"] else 0
            )
            # NOTE(review): additional_notes is collected but never written to
            # any saved annotation payload — confirm whether it should be.
            additional_notes = st.text_area("Additional Notes")
            # Submit annotation
            col1, col2, col3 = st.columns(3)
            with col1:
                if st.button("Previous", disabled=st.session_state.current_index <= 0):
                    # Save current annotation
                    st.session_state.annotations[current_audio_id] = selected_class
                    save_session_state(st.session_state.annotator_name)
                    # Go to previous audio
                    st.session_state.current_index = max(0, st.session_state.current_index - 1)
                    # Load the previous sample
                    with st.spinner("Loading previous sample..."):
                        st.session_state.current_sample = load_single_sample(
                            st.session_state.dataset_name,
                            st.session_state.dataset_split,
                            st.session_state.current_index
                        )
                    st.rerun()
            with col2:
                if st.button("Save Annotation"):
                    if not selected_class:
                        st.warning("Please provide a class label.")
                    elif not st.session_state.annotator_name:
                        st.warning("Please provide your name as the annotator.")
                    else:
                        # Save the annotation
                        st.session_state.annotations[current_audio_id] = selected_class
                        try:
                            # Create a mapping to update the dataset; every
                            # entry is re-stamped with the current time here.
                            annotations_with_details = {}
                            for audio_id, label in st.session_state.annotations.items():
                                annotations_with_details[audio_id] = {
                                    "label": label,
                                    "annotator": st.session_state.annotator_name,
                                    "timestamp": datetime.now().isoformat()
                                }
                            # Save session state
                            save_session_state(st.session_state.annotator_name)
                            # Log the annotation
                            st.success(f"Sample {st.session_state.current_index + 1} annotated as '{selected_class}'")
                            # Save annotations locally (temp_dir is per-session,
                            # so this copy is lost when the session ends)
                            annotations_file = os.path.join(st.session_state.temp_dir, "annotations.json")
                            with open(annotations_file, "w") as f:
                                json.dump({
                                    "dataset": st.session_state.dataset_name,
                                    "split": st.session_state.dataset_split,
                                    "annotator": st.session_state.annotator_name,
                                    "annotations": annotations_with_details,
                                    "annotation_date": datetime.now().isoformat()
                                }, f)
                            st.success(f"Annotation saved locally. You can download the annotations file at the end of your session.")
                            # Save annotations to HF Hub directly if checkbox is selected
                            # NOTE(review): this checkbox is created inside the button
                            # handler, so it only renders during the single rerun in
                            # which "Save Annotation" was clicked and starts unchecked —
                            # this branch likely never executes. Confirm; if the upload
                            # is wanted, the checkbox should live outside the handler.
                            if st.checkbox("Save to Hugging Face directly"):
                                try:
                                    # Create a temporary file
                                    with tempfile.NamedTemporaryFile(suffix='.json', delete=False, mode='w') as f:
                                        json.dump({
                                            "dataset": st.session_state.dataset_name,
                                            "split": st.session_state.dataset_split,
                                            "annotator": st.session_state.annotator_name,
                                            "annotations": annotations_with_details,
                                            "annotation_date": datetime.now().isoformat()
                                        }, f)
                                        temp_filepath = f.name
                                    # Create repo if it doesn't exist (will not error if it does)
                                    repo_id = f"{DEFAULT_DATASET}-annotations"
                                    try:
                                        create_repo(repo_id, private=True, token=DEFAULT_HF_TOKEN, exist_ok=True)
                                    except Exception as e:
                                        st.warning(f"Repository already exists, proceeding with upload: {str(e)}")
                                    # Upload the annotations file
                                    annotations_filename = f"annotations_{st.session_state.annotator_name}_{datetime.now().strftime('%Y%m%d_%H%M%S')}.json"
                                    upload_file(
                                        path_or_fileobj=temp_filepath,
                                        path_in_repo=annotations_filename,
                                        repo_id=repo_id,
                                        token=DEFAULT_HF_TOKEN,
                                        commit_message=f"Annotations by {st.session_state.annotator_name}"
                                    )
                                    st.success(f"Annotations saved to Hugging Face Hub in repository: {repo_id}")
                                    # Remove temp file
                                    os.unlink(temp_filepath)
                                except Exception as e:
                                    st.error(f"Error saving to Hugging Face Hub: {str(e)}")
                                    st.warning("Annotations saved locally only. Please download them for backup.")
                        except Exception as e:
                            st.error(f"Error saving annotation: {str(e)}")
            with col3:
                if st.button("Next", disabled=st.session_state.current_index >= st.session_state.total_samples - 1):
                    # Save current annotation
                    st.session_state.annotations[current_audio_id] = selected_class
                    save_session_state(st.session_state.annotator_name)
                    # Go to next audio
                    st.session_state.current_index = min(st.session_state.total_samples - 1, st.session_state.current_index + 1)
                    # Load the next sample
                    with st.spinner("Loading next sample..."):
                        st.session_state.current_sample = load_single_sample(
                            st.session_state.dataset_name,
                            st.session_state.dataset_split,
                            st.session_state.current_index
                        )
                    st.rerun()
            # Display progress (max(1, ...) avoids division by zero for a
            # single-sample dataset)
            st.progress(st.session_state.current_index / max(1, st.session_state.total_samples - 1))
            # Option to download annotations
            if st.session_state.annotations:
                annotations_with_details = {}
                for audio_id, label in st.session_state.annotations.items():
                    annotations_with_details[audio_id] = {
                        "label": label,
                        "annotator": st.session_state.annotator_name,
                        "timestamp": datetime.now().isoformat()
                    }
                annotations_data = {
                    "dataset": st.session_state.dataset_name,
                    "split": st.session_state.dataset_split,
                    "annotator": st.session_state.annotator_name,
                    "annotations": annotations_with_details,
                    "annotation_date": datetime.now().isoformat()
                }
                st.download_button(
                    "Download Annotations as JSON",
                    data=json.dumps(annotations_data, indent=2),
                    file_name=f"audio_annotations_{st.session_state.annotator_name.replace(' ', '_').lower()}.json",
                    mime="application/json"
                )
            # Button to jump to a specific sample
            st.subheader("Jump to Sample")
            col1, col2 = st.columns([3, 1])
            with col1:
                jump_index = st.number_input("Sample Index", min_value=0, max_value=st.session_state.total_samples-1, value=st.session_state.current_index)
            with col2:
                if st.button("Jump"):
                    # Save current annotation before jumping
                    st.session_state.annotations[current_audio_id] = selected_class
                    save_session_state(st.session_state.annotator_name)
                    # Set new index and load sample
                    st.session_state.current_index = jump_index
                    st.session_state.current_sample = load_single_sample(
                        st.session_state.dataset_name,
                        st.session_state.dataset_split,
                        st.session_state.current_index
                    )
                    st.rerun()
        else:
            st.error("No audio data found in the current sample. Make sure the dataset has an 'audio' column.")
    except Exception as e:
        st.error(f"Error displaying audio: {str(e)}")
        st.write("Error details:", str(e))
else:
    st.info("Please configure and initialize a dataset using the sidebar options.")
# Footer with instructions
st.markdown("---")
st.markdown("""
### Instructions:
1. Enter your name as the annotator (must be on the authorized list)
2. Click "Initialize Dataset" to begin annotation
3. Listen to each audio sample and annotate as:
- hate: Contains hate speech
- non-hate: Does not contain hate speech
- discard: Cannot be categorized or poor audio quality
4. Save your annotations regularly
5. Your progress is automatically saved and will be restored when you return
### Implementation Notes:
- This tool remembers your position and annotations between sessions
- Annotations are saved with original audio filenames as keys
- Only authorized annotators can submit annotations
""")