|
|
|
|
|
|
|
|
import numpy as np |
|
|
import os |
|
|
import sys |
|
|
import pdb |
|
|
|
|
|
|
|
|
def split_segment(prob, sess, spk, start, end, max_dur=2000):
    """Print one segment as RTTM, recursively splitting it if it is too long.

    A segment no longer than ``max_dur`` is printed as a single ``SPEAKER``
    line. A longer one is cut at the index where ``prob`` is smallest,
    searched between ``start + 100`` and ``end - 100`` so that neither piece
    is shorter than 100 units, and both pieces are processed the same way.

    Args:
        prob: Sequence indexed on the same scale as ``start``/``end``; its
            minimum inside the segment selects the cut point. Presumably the
            per-frame activity probability of speaker ``spk`` -- confirm
            against the caller.
        sess: Session name written to the second RTTM field.
        spk: Speaker label written to the RTTM speaker field.
        start: Segment start, in hundredths of the unit that is printed
            (the value is divided by 100 on output).
        end: Segment end, on the same scale as ``start``.
        max_dur: Longest duration (``end - start``) that is printed without
            being split. It now applies at every recursion level; previously
            the recursive calls silently fell back to the default of 2000.

    Returns:
        None. The resulting lines are written to standard output.
    """
    margin = 100  # minimum distance kept between a cut point and either edge
    dur = end - start

    if dur > max_dur:
        lo = int(start + margin)
        hi = int(end - margin)
        # An inverted range must not be sliced: with a negative ``hi`` the
        # slice would wrap around and pick a cut point outside the segment.
        window = prob[lo:hi] if hi > lo else ()
        if len(window) > 0:
            tosplit = lo + int(np.argmin(window))
            split_segment(prob, sess, spk, start, tosplit, max_dur)
            split_segment(prob, sess, spk, tosplit, end, max_dur)
            return
        # No usable cut point (segment shorter than two margins, or ``prob``
        # does not cover it): fall through and emit the segment unsplit
        # instead of failing inside np.argmin on an empty array.

    print("SPEAKER {} 1 {:.2f} {:.2f} <NA> <NA> {} <NA> <NA>".format(sess, start/100., dur/100., spk))
|
|
|
|
|
|
|
|
# Command-line arguments: a directory of .npy arrays and the RTTM file whose
# long segments are to be split.
if len(sys.argv) < 3:
    sys.exit("usage: {} <prob_array_dir> <input_rttm>".format(sys.argv[0]))
prob_array_dir = sys.argv[1]
input_rttm = sys.argv[2]

# Sorted so that, when several files reduce to the same session key below,
# the one that wins does not depend on the arbitrary order of os.listdir().
prob_array = [os.path.join(prob_array_dir, name)
              for name in sorted(os.listdir(prob_array_dir))]

# Maps a session key to the array loaded from the matching .npy file.
prob_label = {}
for p in prob_array:
    # Test the suffix rather than searching the whole path, so a ".npy" that
    # appears elsewhere (e.g. in a directory name) does not let unrelated
    # files through to np.load.
    if not p.endswith(".npy"):
        continue
    session = os.path.basename(p).split('.')[0]
    # Reduce the file stem to a session key:
    #   - contains both "CH" and "S": keep the text before the first "_";
    #   - contains "CH" but no "S":   drop the last "_"-separated field;
    #   - otherwise:                  use the stem unchanged.
    # NOTE(review): presumably this strips array/channel suffixes from the
    # recording names -- confirm against the actual file naming scheme.
    if "CH" in session and "S" in session:
        sess = session.split("_")[0]
    elif "CH" in session:
        sess = "_".join(session.split("_")[:-1])
    else:
        sess = session
    prob_label[sess] = np.load(p)
|
|
# Copy the RTTM to standard output, replacing every segment longer than 3000
# units (onset/duration fields multiplied by 100) with shorter pieces produced
# by split_segment(); shorter segments are echoed unchanged.
with open(input_rttm) as rttm:  # closed automatically, even on error
    for raw in rttm:
        # split() with no argument tolerates tabs, repeated spaces and the
        # trailing newline, none of which end up inside a field.
        fields = raw.split()
        if not fields:
            continue  # skip blank lines instead of failing on fields[1]

        session = fields[1]
        # The speaker label is the second-to-last field unless that one is
        # "<NA>", in which case it is the third-to-last field.
        spk = fields[-2] if fields[-2] != "<NA>" else fields[-3]

        # Round rather than truncate: e.g. 4.35 * 100 evaluates to
        # 434.99999999999994 in floating point and must become 435, not 434.
        start = int(round(float(fields[3]) * 100))
        dur = int(round(float(fields[4]) * 100))
        end = start + dur

        if dur <= 3000:
            print(raw.rstrip())
        else:
            # The speaker label doubles as the row index into the session's
            # array loaded into prob_label.
            split_segment(prob_label[session][int(spk)], session, spk, start, end, max_dur=3000)
|
|
|