Spaces:
Runtime error
Runtime error
Commit
·
e82ec2b
1
Parent(s):
3b31903
fixed pandas SettingWithCopy error
Browse files- preprocessing/preprocess.py +64 -33
preprocessing/preprocess.py
CHANGED
|
@@ -8,22 +8,28 @@ import torchaudio
|
|
| 8 |
import torch
|
| 9 |
from tqdm import tqdm
|
| 10 |
|
| 11 |
-
|
|
|
|
| 12 |
return f"{url.split('/')[-1]}.wav"
|
| 13 |
|
| 14 |
-
|
|
|
|
| 15 |
audio_urls = audio_urls.replace(".", np.nan)
|
| 16 |
audio_files = set(os.path.basename(f) for f in Path(audio_dir).iterdir())
|
| 17 |
-
valid_audio_mask = audio_urls.apply(
|
|
|
|
|
|
|
| 18 |
return valid_audio_mask
|
| 19 |
|
| 20 |
-
|
|
|
|
| 21 |
"""
|
| 22 |
-
Tests audio urls to ensure that their file exists and the contents is valid.
|
| 23 |
"""
|
| 24 |
audio_files = set(os.path.basename(f) for f in Path(audio_dir).iterdir())
|
|
|
|
| 25 |
def is_valid(url):
|
| 26 |
-
valid_url = type(url) == str and "http" in url
|
| 27 |
if not valid_url:
|
| 28 |
return False
|
| 29 |
filename = url_to_filename(url)
|
|
@@ -33,23 +39,29 @@ def validate_audio(audio_urls:pd.Series, audio_dir:str) -> pd.Series:
|
|
| 33 |
w, _ = torchaudio.load(os.path.join(audio_dir, filename))
|
| 34 |
except:
|
| 35 |
return False
|
| 36 |
-
contents_invalid =
|
|
|
|
|
|
|
|
|
|
|
|
|
| 37 |
return not contents_invalid
|
| 38 |
-
|
| 39 |
idxs = []
|
| 40 |
validations = []
|
| 41 |
-
for index, url in tqdm(
|
|
|
|
|
|
|
| 42 |
idxs.append(index)
|
| 43 |
validations.append(is_valid(url))
|
| 44 |
|
| 45 |
return pd.Series(validations, index=idxs)
|
| 46 |
-
|
| 47 |
-
|
| 48 |
|
| 49 |
-
|
|
|
|
| 50 |
tag_pattern = re.compile("([A-Za-z]+)(\+|-)(\d+)")
|
| 51 |
-
dance_ratings = dance_ratings.apply(lambda v
|
| 52 |
-
|
|
|
|
| 53 |
new_labels = {}
|
| 54 |
for k, v in labels.items():
|
| 55 |
match = tag_pattern.search(k)
|
|
@@ -57,21 +69,25 @@ def fix_dance_rating_counts(dance_ratings:pd.Series) -> pd.Series:
|
|
| 57 |
new_labels[k] = new_labels.get(k, 0) + v
|
| 58 |
else:
|
| 59 |
k = match[1]
|
| 60 |
-
sign = 1 if match[2] ==
|
| 61 |
scale = int(match[3])
|
| 62 |
new_labels[k] = new_labels.get(k, 0) + v * scale * sign
|
| 63 |
valid = any(v > 0 for v in new_labels.values())
|
| 64 |
return new_labels if valid else np.nan
|
|
|
|
| 65 |
return dance_ratings.apply(fix_labels)
|
| 66 |
|
| 67 |
|
| 68 |
-
def get_unique_labels(dance_labels:pd.Series) -> list:
|
| 69 |
labels = set()
|
| 70 |
for dances in dance_labels:
|
| 71 |
labels |= set(dances)
|
| 72 |
return sorted(labels)
|
| 73 |
|
| 74 |
-
|
|
|
|
|
|
|
|
|
|
| 75 |
"""
|
| 76 |
Turns label dict into probability distribution vector based on each label count.
|
| 77 |
"""
|
|
@@ -80,37 +96,53 @@ def vectorize_label_probs(labels: dict[str,int], unique_labels:np.ndarray) -> np
|
|
| 80 |
item_vec = (unique_labels == k) * v
|
| 81 |
label_vec += item_vec
|
| 82 |
lv_cache = label_vec.copy()
|
| 83 |
-
label_vec[label_vec<0] = 0
|
| 84 |
label_vec /= label_vec.sum()
|
| 85 |
assert not any(np.isnan(label_vec)), f"Provided labels are invalid: {labels}"
|
| 86 |
return label_vec
|
| 87 |
|
| 88 |
-
|
|
|
|
|
|
|
|
|
|
| 89 |
"""
|
| 90 |
Turns label dict into binary label vectors for multi-label classification.
|
| 91 |
"""
|
| 92 |
-
probs = vectorize_label_probs(labels,unique_labels)
|
| 93 |
probs[probs > 0.0] = 1.0
|
| 94 |
return probs
|
| 95 |
|
| 96 |
-
|
| 97 |
-
|
|
|
|
|
|
|
|
|
|
| 98 |
sampled_songs["DanceRating"] = fix_dance_rating_counts(sampled_songs["DanceRating"])
|
| 99 |
if class_list is not None:
|
| 100 |
class_list = set(class_list)
|
| 101 |
sampled_songs["DanceRating"] = sampled_songs["DanceRating"].apply(
|
| 102 |
-
lambda labels
|
| 103 |
-
if not pd.isna(labels)
|
| 104 |
-
|
|
|
|
|
|
|
| 105 |
sampled_songs = sampled_songs.dropna(subset=["DanceRating"])
|
| 106 |
-
vote_mask = sampled_songs["DanceRating"].apply(
|
|
|
|
|
|
|
| 107 |
sampled_songs = sampled_songs[vote_mask]
|
| 108 |
-
labels = sampled_songs["DanceRating"].apply(
|
|
|
|
|
|
|
|
|
|
|
|
|
| 109 |
unique_labels = np.array(get_unique_labels(labels))
|
| 110 |
vectorizer = vectorize_multi_label if multi_label else vectorize_label_probs
|
| 111 |
-
labels = labels.apply(lambda i
|
| 112 |
|
| 113 |
-
audio_paths = [
|
|
|
|
|
|
|
| 114 |
|
| 115 |
return np.array(audio_paths), np.stack(labels)
|
| 116 |
|
|
@@ -119,12 +151,11 @@ if __name__ == "__main__":
|
|
| 119 |
links = pd.read_csv("data/backup_2.csv", index_col="index")
|
| 120 |
df = pd.read_csv("data/songs.csv")
|
| 121 |
l = links["link"].str.strip()
|
| 122 |
-
l = l.apply(lambda url
|
| 123 |
l = l.dropna()
|
| 124 |
df["Sample"].update(l)
|
| 125 |
-
addna = lambda url
|
| 126 |
df["Sample"] = df["Sample"].apply(addna)
|
| 127 |
-
is_valid = validate_audio(df["Sample"],"data/samples")
|
| 128 |
df["valid"] = is_valid
|
| 129 |
df.to_csv("data/songs_validated.csv")
|
| 130 |
-
|
|
|
|
| 8 |
import torch
|
| 9 |
from tqdm import tqdm
|
| 10 |
|
| 11 |
+
|
| 12 |
+
def url_to_filename(url: str) -> str:
    """Map an audio URL to its local cache filename: last path segment + '.wav'."""
    basename = url.rsplit("/", 1)[-1]
    return basename + ".wav"
|
| 14 |
|
| 15 |
+
|
| 16 |
+
def has_valid_audio(audio_urls: pd.Series, audio_dir: str) -> pd.Series:
    """
    Return a boolean mask marking which URLs map to a file present in audio_dir.

    Only checks file presence (by the filename derived from the URL), not file
    contents — content-level checks live in validate_audio.

    audio_urls: series of sample URLs; "." placeholders are treated as missing.
    audio_dir: directory whose filenames are matched against url_to_filename(url).
    """
    audio_urls = audio_urls.replace(".", np.nan)
    # Build the filename set once, outside the per-row apply.
    audio_files = set(os.path.basename(f) for f in Path(audio_dir).iterdir())
    # isinstance(url, str) instead of `url is not np.nan`: identity comparison
    # misses NaN floats that are not the np.nan singleton (e.g. NaNs produced
    # by CSV parsing), which would then crash url_to_filename on a float.
    valid_audio_mask = audio_urls.apply(
        lambda url: isinstance(url, str) and url_to_filename(url) in audio_files
    )
    return valid_audio_mask
|
| 23 |
|
| 24 |
+
|
| 25 |
+
def validate_audio(audio_urls: pd.Series, audio_dir: str) -> pd.Series:
|
| 26 |
"""
|
| 27 |
+
Tests audio urls to ensure that their file exists and the contents is valid.
|
| 28 |
"""
|
| 29 |
audio_files = set(os.path.basename(f) for f in Path(audio_dir).iterdir())
|
| 30 |
+
|
| 31 |
def is_valid(url):
|
| 32 |
+
valid_url = type(url) == str and "http" in url
|
| 33 |
if not valid_url:
|
| 34 |
return False
|
| 35 |
filename = url_to_filename(url)
|
|
|
|
| 39 |
w, _ = torchaudio.load(os.path.join(audio_dir, filename))
|
| 40 |
except:
|
| 41 |
return False
|
| 42 |
+
contents_invalid = (
|
| 43 |
+
torch.any(torch.isnan(w))
|
| 44 |
+
or torch.any(torch.isinf(w))
|
| 45 |
+
or len(torch.unique(w)) <= 2
|
| 46 |
+
)
|
| 47 |
return not contents_invalid
|
| 48 |
+
|
| 49 |
idxs = []
|
| 50 |
validations = []
|
| 51 |
+
for index, url in tqdm(
|
| 52 |
+
audio_urls.items(), total=len(audio_urls), desc="Audio URLs Validated"
|
| 53 |
+
):
|
| 54 |
idxs.append(index)
|
| 55 |
validations.append(is_valid(url))
|
| 56 |
|
| 57 |
return pd.Series(validations, index=idxs)
|
|
|
|
|
|
|
| 58 |
|
| 59 |
+
|
| 60 |
+
def fix_dance_rating_counts(dance_ratings: pd.Series) -> pd.Series:
|
| 61 |
tag_pattern = re.compile("([A-Za-z]+)(\+|-)(\d+)")
|
| 62 |
+
dance_ratings = dance_ratings.apply(lambda v: json.loads(v.replace("'", '"')))
|
| 63 |
+
|
| 64 |
+
def fix_labels(labels: dict) -> dict | float:
|
| 65 |
new_labels = {}
|
| 66 |
for k, v in labels.items():
|
| 67 |
match = tag_pattern.search(k)
|
|
|
|
| 69 |
new_labels[k] = new_labels.get(k, 0) + v
|
| 70 |
else:
|
| 71 |
k = match[1]
|
| 72 |
+
sign = 1 if match[2] == "+" else -1
|
| 73 |
scale = int(match[3])
|
| 74 |
new_labels[k] = new_labels.get(k, 0) + v * scale * sign
|
| 75 |
valid = any(v > 0 for v in new_labels.values())
|
| 76 |
return new_labels if valid else np.nan
|
| 77 |
+
|
| 78 |
return dance_ratings.apply(fix_labels)
|
| 79 |
|
| 80 |
|
| 81 |
+
def get_unique_labels(dance_labels: pd.Series) -> list:
    """Collect the sorted set of dance names appearing across all label dicts."""
    seen = set()
    for entry in dance_labels:
        # Iterating a dict yields its keys, so update() adds the dance names.
        seen.update(entry)
    return sorted(seen)
|
| 86 |
|
| 87 |
+
|
| 88 |
+
def vectorize_label_probs(
|
| 89 |
+
labels: dict[str, int], unique_labels: np.ndarray
|
| 90 |
+
) -> np.ndarray:
|
| 91 |
"""
|
| 92 |
Turns label dict into probability distribution vector based on each label count.
|
| 93 |
"""
|
|
|
|
| 96 |
item_vec = (unique_labels == k) * v
|
| 97 |
label_vec += item_vec
|
| 98 |
lv_cache = label_vec.copy()
|
| 99 |
+
label_vec[label_vec < 0] = 0
|
| 100 |
label_vec /= label_vec.sum()
|
| 101 |
assert not any(np.isnan(label_vec)), f"Provided labels are invalid: {labels}"
|
| 102 |
return label_vec
|
| 103 |
|
| 104 |
+
|
| 105 |
+
def vectorize_multi_label(
    labels: dict[str, int], unique_labels: np.ndarray
) -> np.ndarray:
    """
    Turns label dict into binary label vectors for multi-label classification.
    """
    # Start from the probability vector, then binarize: any positive mass -> 1.
    vec = vectorize_label_probs(labels, unique_labels)
    positive = vec > 0.0
    vec[positive] = 1.0
    return vec
|
| 114 |
|
| 115 |
+
|
| 116 |
+
def get_examples(
    df: pd.DataFrame, audio_dir: str, class_list=None, multi_label=True, min_votes=1
) -> tuple[np.ndarray, np.ndarray]:
    """
    Build (audio_paths, label_vectors) training examples from the songs table.

    df: songs dataframe with a "Sample" column (audio URL) and a "DanceRating"
        column (stringified label->count dict) — assumed schema, confirm with caller.
    audio_dir: directory holding the downloaded .wav samples.
    class_list: optional whitelist of dance labels; rows with no positive
        whitelisted label are dropped.
    multi_label: if True, labels become binary multi-hot vectors; otherwise
        probability-distribution vectors.
    min_votes: minimum vote count for a label (and its row) to be kept.
    """
    # Deep copy so the column assignments below write to an independent frame
    # (the commit message indicates this fixed a pandas SettingWithCopy error).
    sampled_songs = df[has_valid_audio(df["Sample"], audio_dir)].copy(deep=True)
    sampled_songs["DanceRating"] = fix_dance_rating_counts(sampled_songs["DanceRating"])
    if class_list is not None:
        class_list = set(class_list)
        # Keep only whitelisted labels; rows whose dict is NaN, or that have no
        # whitelisted label with a positive count, are set to NaN so the
        # dropna() below removes them.
        sampled_songs["DanceRating"] = sampled_songs["DanceRating"].apply(
            lambda labels: {k: v for k, v in labels.items() if k in class_list}
            if not pd.isna(labels)
            and any(label in class_list and amt > 0 for label, amt in labels.items())
            else np.nan
        )
    sampled_songs = sampled_songs.dropna(subset=["DanceRating"])
    # A row survives when at least one of its dances meets the vote threshold...
    vote_mask = sampled_songs["DanceRating"].apply(
        lambda dances: any(votes >= min_votes for votes in dances.values())
    )
    sampled_songs = sampled_songs[vote_mask]
    # ...and then only the dances meeting the threshold are kept as its labels.
    labels = sampled_songs["DanceRating"].apply(
        lambda dances: {
            dance: votes for dance, votes in dances.items() if votes >= min_votes
        }
    )
    unique_labels = np.array(get_unique_labels(labels))
    vectorizer = vectorize_multi_label if multi_label else vectorize_label_probs
    # Each row's label dict becomes a vector aligned to unique_labels.
    labels = labels.apply(lambda i: vectorizer(i, unique_labels))

    # Audio paths are derived from the URL, mirroring has_valid_audio's naming.
    audio_paths = [
        os.path.join(audio_dir, url_to_filename(url)) for url in sampled_songs["Sample"]
    ]

    return np.array(audio_paths), np.stack(labels)
|
| 148 |
|
|
|
|
| 151 |
links = pd.read_csv("data/backup_2.csv", index_col="index")
|
| 152 |
df = pd.read_csv("data/songs.csv")
|
| 153 |
l = links["link"].str.strip()
|
| 154 |
+
l = l.apply(lambda url: url if "http" in url else np.nan)
|
| 155 |
l = l.dropna()
|
| 156 |
df["Sample"].update(l)
|
| 157 |
+
addna = lambda url: url if type(url) == str and "http" in url else np.nan
|
| 158 |
df["Sample"] = df["Sample"].apply(addna)
|
| 159 |
+
is_valid = validate_audio(df["Sample"], "data/samples")
|
| 160 |
df["valid"] = is_valid
|
| 161 |
df.to_csv("data/songs_validated.csv")
|
|
|