Spaces:
Sleeping
Sleeping
| import numpy as np | |
| import json, subprocess, librosa | |
| import matplotlib.pyplot as plt | |
| import matplotlib.colors as colors | |
| from pydub import AudioSegment | |
| # based on captinifeedback.py | |
| # with extra/experimental visual outputs | |
| # for huggingface internal demo | |
class FeedbackConverter():
    """Convert raw pronunciation scores into user-facing feedback.

    Thresholds come from two key files loaded at construction time:
    a task key (JSON: task_id -> {'word': thr, 'phone': thr}) and a
    phone key (TSV: phone<TAB>threshold) for the monophone fallback.
    """

    def __init__(self, task_key_path, phone_key_path, lower_bound_100, upper_bound_100, not_scored_value = "TOO SHORT TO SCORE"):
        # task_key_path / phone_key_path: paths to the score key files
        # lower_bound_100 / upper_bound_100: raw-score interval mapped to 0-100
        # not_scored_value: sentinel used in place of a raw score for
        #   segments too short to score
        self.task_key_path = task_key_path
        self.phone_key_path = phone_key_path
        self.lower_bound_100 = lower_bound_100
        self.upper_bound_100 = upper_bound_100
        self.not_scored_value = not_scored_value
        self.range_100 = self.upper_bound_100 - self.lower_bound_100
        try:
            with open(phone_key_path,'r') as handle:
                phone_key = handle.read().splitlines()
            phone_key = [l.split('\t') for l in phone_key]
            self.phone_key = {phone : float(binary_threshold) for phone, binary_threshold in phone_key}
            with open(task_key_path,'r') as handle:
                self.task_key = json.load(handle)
        except Exception as e:
            # was a bare `except:`; keep the cause chained so the real
            # parse/IO error is not lost
            raise Exception(f"At least one of the score key files {task_key_path} or {phone_key_path} couldn't be loaded.") from e
| # feedback for task-based scoring ----- | |
| def scale_binary_task(self,raw_score,unit,task_id): | |
| if raw_score == self.not_scored_value: | |
| return 1 # 1: be generous in case not scored | |
| elif raw_score >= self.task_key[task_id][unit]: | |
| return 1 # 1: above threshold, correct pronunciation | |
| else: | |
| return 0 # 0: below threshold, mispronunciation | |
| def b_list_task(self,scores_list,unit,task_id): | |
| return [(label, self.scale_binary_task(score,unit,task_id)) for label,score in scores_list] | |
| # scale score from interval [-1,1] to integers [0,100] | |
| # alternately could replace this with % of phones correct. | |
| def scale_100(self,raw_score): | |
| if raw_score == self.not_scored_value: | |
| return 100 # consider smoothing a different way if this ends up used | |
| elif raw_score <= self.lower_bound_100: | |
| return 0 | |
| elif raw_score >= self.upper_bound_100: | |
| return 100 | |
| else: | |
| rescaled_score = (raw_score - self.lower_bound_100) / self.range_100 | |
| return round(100*rescaled_score) | |
    # heuristics?
    # return 1 (correct) for phones/words that are too short to score,
    # EXCEPT when a word has score 0 and all phones in that word are too short,
    # then return 0 for all of that word's phones.
    # also, if a word has score 0 but all individual phones have binary score 1
    # (as a real score, not when they are all too short),
    # CHANGE the lowest phone score to 0 so there is some corrective feedback
    # TODO turn that part off it if overcorrects native speakers
    def wordfix(self,word_phone_scores, word_score, task_id):
        """Binarize one word's phone scores, applying the heuristics above.

        word_phone_scores: list of (phone_label, raw_score) pairs; raw
            scores may be self.not_scored_value.
        word_score: the word's already-binarized score (0 or 1).
        task_id: key into self.task_key for the per-task thresholds.
        Returns a list of (phone_label, 0-or-1) pairs.
        """
        if word_score == 1:
            # word judged correct: binarize its phones normally
            return self.b_list_task(word_phone_scores,'phone',task_id)
        elif all([sc == self.not_scored_value for ph,sc in word_phone_scores]):
            # word scored 0 and no phone was scorable: mark every phone wrong
            return [(ph, 0) for ph,sc in word_phone_scores ]
        else:
            bin_scores = self.b_list_task(word_phone_scores,'phone',task_id)
            if all([sc == 1 for ph,sc in bin_scores]):
                # word is 0 yet every phone binarized to 1: flip the phone
                # with the lowest raw score so there is corrective feedback.
                # Unscored phones are given raw value 1 so they are not picked
                # (assumes raw scores stay within [-1,1] — TODO confirm).
                sc_list = [1 if sc == self.not_scored_value
                           else sc for ph,sc in word_phone_scores]
                min_ix = sc_list.index(min(sc_list))
                bin_scores[min_ix] = (bin_scores[min_ix][0],0)
            return bin_scores
| # feedback for fallback phone scoring ----- | |
| def scale_binary_monophone(self,raw_score,phone_id): | |
| if raw_score == self.not_scored_value: | |
| return 1 | |
| elif raw_score >= self.phone_key[phone_id]: | |
| return 1 | |
| else: | |
| return 0 | |
| def b_list_monophone(self,scores_list): | |
| return [(label, self.scale_binary_monophone(score,label)) for label,score in scores_list] | |
| # score word 0 if any phone is 0, else 1 | |
| # TODO may cause overcorrection of native speakers, | |
| # or confusing inconsistency with 0-100 task score, | |
| # consider word score by average of phone raw scores instead | |
| def b_wordfromphone(self,phone_bins): | |
| return [( word, min([b for p,b in b_phones]) ) for word, b_phones in phone_bins] | |
| # yield score out of 100 as percent of phones correct | |
| def scale_100_monophone(self,phone_bins): | |
| plist = [] | |
| for w, b_phones in phone_bins: | |
| plist += [b for p,b in b_phones] | |
| return int(100*np.nanmean(plist)) | |
| ### -------- some colour printing.... | |
| # sort into 3 colours for printing | |
| # good, mispronounced, unable to score | |
| def phone_3sort_monophone(self,raw_score,phone_id): | |
| if raw_score == self.not_scored_value: | |
| return -1, phone_id | |
| elif raw_score >= self.phone_key[phone_id]: | |
| return 1, phone_id | |
| else: | |
| return 0, phone_id | |
| def phone_3sort_task(self,raw_score,unit,task_id, label): | |
| if raw_score == self.not_scored_value: | |
| return -1, label | |
| elif raw_score >= self.task_key[task_id][unit]: | |
| return 1, label | |
| else: | |
| return 0, label | |
| # put out html | |
| def hc_from_3(self, scoretype, pcontent): | |
| if scoretype == -1: # not scored value | |
| return f"<span style='color:#BBBBBB;'>{pcontent}</span>" | |
| elif scoretype == 1: # correct | |
| return f"<span style='color:#0000FF;'>{pcontent}</span>" | |
| elif scoretype == 0: # wrong | |
| return f"<span style='color:#FF0000;'>{pcontent}</span>" | |
| else: # error | |
| return f"<span>{pcontent}</span>" | |
| def c3_list_monophone(self,scores_list): | |
| #return ''.join([ hc_from_3(self.phone_3sort_monophone(score,label)) for label,score in scores_list ]) | |
| return [self.phone_3sort_monophone(score,label) for label,score in scores_list] | |
| def c3_list_task(self,scores_list,unit,task_id): | |
| #return ''.join([ hc_from_3(self.phone_3sort_task(score,unit,task_id,label)) for label,score in scores_list]) | |
| return [self.phone_3sort_task(score,unit,task_id,label) for label,score in scores_list] | |
    # output is:
    # - one score 0-100 for the entire task
    # - a score 0/1 for each word
    # - a score 0/1 for each phone
    def convert(self,word_scores,phone_scores,task_id):
        """Turn raw word/phone scores into user-facing feedback.

        word_scores: list of (word_label, raw_score) pairs.
        phone_scores: list of (word_label, [(phone_label, raw_score), ...])
            in the same word order as word_scores.
        task_id: if present in self.task_key the task-specific thresholds
            are used, otherwise the monophone fallback.
        Returns (task_score_0_to_100, word_feedback, phone_feedback); the
        phone feedback returned is the 3-way (-1/0/1, label) colour form.
        """
        if task_id in self.task_key.keys(): # score with full task model
            # nanmean over scorable words; `or 1` makes an all-unscorable
            # (empty) list fall back to raw score 1, i.e. 100 after scaling
            task_fb = self.scale_100( np.nanmean([sc for wd,sc in word_scores if sc != self.not_scored_value]
                                                 or 1) )
            word_fb = self.b_list_task(word_scores,'word',task_id)
            # binary per-phone feedback with the wordfix heuristics applied
            phone_fb = [(p_sc[0], self.wordfix(p_sc[1],w_fb[1],task_id) )
                        for w_fb, p_sc in zip(word_fb,phone_scores)]
            # 3-way categories for colour display
            phone_fb2 = [(p_sc[0], self.c3_list_task(p_sc[1],'phone',task_id) )
                         for w_fb, p_sc in zip(word_fb,phone_scores)]
        else: # score with fallback monophone model
            phone_fb = [(p_sc[0], self.b_list_monophone(p_sc[1]) ) for p_sc in phone_scores]
            word_fb = self.b_wordfromphone(phone_fb)
            task_fb = self.scale_100_monophone(phone_fb)
            phone_fb2 = [(p_sc[0], self.c3_list_monophone(p_sc[1]) ) for p_sc in phone_scores]
        # NOTE: the binary phone_fb is computed but not returned; the
        # 3-way phone_fb2 is what callers currently get.
        #return(task_fb, word_fb, phone_fb)
        return(task_fb, word_fb, phone_fb2)
| # ----------------------- stuff for visual ....... | |
| # TODO 2pass... | |
| def get_pitch_tracks(self,sound_path): | |
| reaper_exec = "/home/user/app/REAPER/build/reaper" | |
| orig_ftype = sound_path.split('.')[-1] | |
| if orig_ftype == '.wav': | |
| wav_path = sound_path | |
| else: | |
| aud_data = AudioSegment.from_file(sound_path, orig_ftype) | |
| curdir = subprocess.run(["pwd"], capture_output=True, text=True) | |
| curdir = curdir.stdout.splitlines()[0] | |
| fname = sound_path.split('/')[-1].replace(orig_ftype,'') | |
| tmp_path = f'{curdir}/{fname}_tmp.wav' | |
| aud_data.export(tmp_path, format="wav") | |
| wav_path = tmp_path | |
| f0_data = subprocess.run([reaper_exec, "-i", wav_path, '-f', '/dev/stdout', '-a'],capture_output=True).stdout | |
| f0_data = f0_data.decode() | |
| f0_data = f0_data.split('EST_Header_End\n')[1].splitlines() | |
| f0_data = [l.split(' ') for l in f0_data] | |
| f0_data = [l for l in f0_data if len(l) == 3] # the last line or 2 lines are other info, different format | |
| f0_data = [ [float(t), float(f)] for t,v,f in f0_data if v=='1'] | |
| if orig_ftype != '.wav': | |
| subprocess.run(["rm", tmp_path]) | |
| return f0_data | |
    # display colour corresponding to a gradient score per phone
    def generate_graphic_feedback_blocks(self,phone_scores):
        """Render a one-row heatmap with one coloured cell per phone.

        phone_scores: list of (word_label, [(phone_label, raw_score), ...]);
            phones with self.not_scored_value get a colour interpolated
            from their nearest scored neighbours.
        Returns the matplotlib Figure.
        """
        plt.close('all')
        # flatten to one [phone_label, score] list across all words
        phone_scores = [phs for wrd, phs in phone_scores]
        phone_scores = [lc for phs in phone_scores for lc in phs]
        # replace the unscorable sentinel with nan so it can be interpolated
        phone_scores = [[p,np.nan] if c == self.not_scored_value else [p,c] for p,c in phone_scores]
        # fill each nan in-place with the mean of the previous value (which
        # may itself have just been filled) and the next non-nan to the right
        for i in range(len(phone_scores)):
            if np.isnan(phone_scores[i][1]):
                prev_c = phone_scores[max(i-1,0)][1] # would be nan only in case when i==0
                j = min(i+1,len(phone_scores)-1)
                next_c = np.nan
                # scan right for the next non-nan score
                while (np.isnan(next_c) and j < len(phone_scores)):
                    next_c = phone_scores[j][1]
                    j += 1
                # at least one of these has value unless the entire stimulus is nan score
                phone_scores[i][1] = np.nanmean([prev_c, next_c])
        fig, axs = plt.subplots( figsize=(7, 1.5 ))
        #plt.gca().set_aspect(1)
        plt.ylim(-1,1.5)
        # hide all ticks/labels; the cells and phone symbols carry the info
        axs.tick_params(left=False, bottom=False, labelleft=False, labelbottom=False)
        # scores outside [lower_bound_100, upper_bound_100] are clipped
        axs.pcolormesh([[c for p,c in phone_scores]],
                       cmap="rainbow_r", norm=colors.Normalize(vmin=self.lower_bound_100-0.01, vmax=self.upper_bound_100+0.01,clip=True ))
        #cmap="plasma")
        fig.tight_layout()
        # label each cell with its phone symbol
        for phi, pinfo in enumerate(phone_scores):
            plt.text(phi+0.5,-0.5,pinfo[0], ha='center',va='center',color='black',size=12)#, rotation=rt)
        return fig
    # TODO:
    # add subphone / frame level DTW feedback shading?
    def generate_graphic_feedback_0(self, sound_path, word_aligns, phone_aligns, phone_feedback, opts):
        """Plot the pitch track with word/phone boundaries and colour-coded phones.

        sound_path: audio file (pitch via REAPER, RMS energy via librosa).
        word_aligns: list of (word_label, start_sec, end_sec) in time order.
        phone_aligns: dict word_label -> [(phone, start, end), ...]; phone
            times appear to be relative to the word start since they are
            offset by it below -- TODO confirm against the aligner.
        phone_feedback: list of (word_label, [(category, phone_label), ...])
            with 3-way categories -1/0/1 as produced by c3_list_*.
        opts: truthy to overlay the RMS energy curve on a twin axis.
        Returns the matplotlib Figure.
        """
        plt.close('all')
        rec_start = word_aligns[0][1]
        rec_end = word_aligns[-1][2]
        f0_data = self.get_pitch_tracks(sound_path)
        if f0_data:
            # leave headroom above the highest pitch point
            f_max = max([f0 for t,f0 in f0_data]) + 50
        else:
            # no voiced frames found: fall back to a fixed y-range
            f_max = 400
        fig, axes1 = plt.subplots(figsize=(15,3))
        plt.xlim([rec_start, rec_end])
        axes1.set_ylim([0.0, f_max])
        axes1.get_xaxis().set_visible(False)
        # word boundary lines plus the word label above the plot
        for w,s,e in word_aligns:
            plt.vlines(s,0,f_max,linewidth=0.5,color='black')
            plt.vlines(e,0,f_max,linewidth=0.5,color='dimgrey')
            #plt.text( (s+e)/2 - (len(w)*.01), f_max+15, w, fontsize=15)
            # word labels look like 'prefix__word'; show only the word part
            plt.text( (s+e)/2, f_max+15, w.split('__')[1], fontsize=15, ha="center")
        # arrange aligns for graphs...
        phone_aligns = [(wrd,phs) for wrd, phs in phone_aligns.items()]
        # NOTE(review): ordering relies on the first 3 chars of the word
        # label sorting into utterance order -- verify the naming scheme
        phone_aligns = sorted(phone_aligns, key = lambda x: x[0][:3])
        # shift each word's phone times by the word start to absolute time
        phone_amends = zip([s for w,s,e in word_aligns], [phs for wrd, phs in phone_aligns])
        phone_aligns = [[(p, s+offset, e+offset) for p,s,e in wphones] for offset, wphones in phone_amends]
        phone_aligns = [p for wrps in phone_aligns for p in wrps]
        # flatten feedback in parallel with the flattened aligns
        phone_feedback = [phs for wrd, phs in phone_feedback]
        phone_feedback = [p for wrps in phone_feedback for p in wrps]
        phone_infos = zip(phone_aligns, phone_feedback)
        # basic 3way phone to colour key
        #cdict = {-1: 'gray', 0: 'red', 1: 'blue'}
        cdict = {-1: 'gray', 0: "#E85907", 1: "#26701C"}
        # dashed phone boundaries plus the phone symbol below the axis,
        # coloured by its feedback category
        for paln, pfbk in phone_infos:
            ph_id, s, e = paln
            c, p = pfbk
            plt.vlines(s,0,f_max,linewidth=0.3,color='cadetblue',linestyle=(0,(10,4)))
            plt.vlines(e,0,f_max,linewidth=0.3,color='cadetblue',linestyle=(0,(10,4)))
            plt.text( (s+e)/2 - (len(p)*.01), -1*f_max/10, p, fontsize=18, color = cdict[c])#color='teal')
        #f0c = "blueviolet"
        #enc = 'peachpuff'
        f0c = "#88447F"
        enc = "#F49098"
        # pitch points
        axes1.scatter([t for t,f0 in f0_data], [f0 for t,f0 in f0_data], color=f0c)
        # add rmse
        w, sr = librosa.load(sound_path)
        fr_l = 2048 # librosa default
        h_l = 512 # default
        rmse = librosa.feature.rms(y=w, frame_length = fr_l, hop_length = h_l)
        rmse = rmse[0]
        # show rms energy, only if opts
        if opts:
            axes2 = axes1.twinx()
            axes2.set_ylim([0.0, 0.5])
            # x position of each RMS frame from its hop index
            rms_xval = [(h_l*i)/sr for i in range(len(rmse))]
            axes2.plot(rms_xval,rmse,color=enc,linewidth=3.5)
        fig.tight_layout()
        return fig