NeoPy commited on
Commit
b1cded8
·
verified ·
1 Parent(s): 1ff2a9f
tools/utils/gdown.py ADDED
@@ -0,0 +1,175 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import re
3
+ import sys
4
+ import json
5
+ import tqdm
6
+ import codecs
7
+ import tempfile
8
+ import requests
9
+
10
+ from urllib.parse import urlparse, parse_qs, unquote
11
+
12
+ sys.path.append(os.getcwd())
13
+
14
+ from main.app.variables import translations
15
+
16
def parse_url(url):
    """Extract a Google Drive file id from *url*.

    Returns a ``(file_id, is_download_link)`` tuple. ``file_id`` is ``None``
    when the URL is not a recognized Drive/Docs link; ``is_download_link``
    is True when the URL path ends with ``/uc`` (the direct-download
    endpoint), regardless of host.
    """
    parsed = urlparse(url)
    is_download_link = parsed.path.endswith("/uc")

    # Only Google Drive/Docs hosts can carry a file id.
    if parsed.hostname not in ("drive.google.com", "docs.google.com"):
        return None, is_download_link

    # Direct-download style links put the id in the query string.
    file_id = parse_qs(parsed.query).get("id", [None])[0]

    if file_id is None:
        # Viewer-style links embed the id in the path instead.
        for pattern in (
            r"^/file/d/(.*?)/(edit|view)$",
            r"^/file/u/[0-9]+/d/(.*?)/(edit|view)$",
            r"^/document/d/(.*?)/(edit|htmlview|view)$",
            r"^/document/u/[0-9]+/d/(.*?)/(edit|htmlview|view)$",
            r"^/presentation/d/(.*?)/(edit|htmlview|view)$",
            r"^/presentation/u/[0-9]+/d/(.*?)/(edit|htmlview|view)$",
            r"^/spreadsheets/d/(.*?)/(edit|htmlview|view)$",
            r"^/spreadsheets/u/[0-9]+/d/(.*?)/(edit|htmlview|view)$"
        ):
            match = re.match(pattern, parsed.path)

            if match:
                file_id = match.group(1)
                break

    return file_id, is_download_link
41
+
42
def get_url_from_gdrive_confirmation(contents):
    """Parse Google Drive's confirmation-page HTML and return the real
    download URL.

    Tries, in order: a relative ``/uc?export=download`` href, an
    ``/open?id=`` href (which needs the hidden ``uuid`` field appended),
    and a JSON-escaped ``downloadUrl`` value.

    Raises ``Exception`` with Drive's own error subcaption when the page
    carries one, otherwise with the translated generic gdown error.
    """
    for pattern in (
        r'href="(\/uc\?export=download[^"]+)',
        r'href="/open\?id=([^"]+)"',
        r'"downloadUrl":"([^"]+)'
    ):
        match = re.search(pattern, contents)
        if not match:
            continue

        url = match.group(1)

        if pattern == r'href="/open\?id=([^"]+)"':
            # The /open?id= form must be confirmed with the page's hidden
            # uuid field. (rot13 merely obfuscates the literal hostname.)
            url = (
                codecs.decode("uggcf://qevir.hfrepbagrag.tbbtyr.pbz/qbjaybnq?vq=", "rot13") +
                url +
                "&confirm=t&uuid=" +
                re.search(r'<input\s+type="hidden"\s+name="uuid"\s+value="([^"]+)"', contents).group(1)
            )
        elif pattern == r'"downloadUrl":"([^"]+)':
            # JSON-escaped URL: restore '=' and '&'.
            url = url.replace("\\u003d", "=").replace("\\u0026", "&")
        else:
            # Relative /uc href: prefix the host and unescape the HTML
            # entity '&amp;' (the original no-op replace("&", "&") was a
            # mangled form of this).
            url = (
                codecs.decode("uggcf://qbpf.tbbtyr.pbz", "rot13") +
                url.replace("&amp;", "&")
            )

        return url

    match = re.search(r'<p class="uc-error-subcaption">(.*)</p>', contents)
    if match: raise Exception(match.group(1))

    raise Exception(translations["gdown_error"])
76
+
77
def _get_session(use_cookies, return_cookies_file=False):
    """Build a requests session with a desktop User-Agent.

    When *use_cookies* is true and the gdown cookie cache file exists,
    its stored (name, value) pairs are loaded into the session's jar.
    Returns the session, or ``(session, cookies_file)`` when
    *return_cookies_file* is true.
    """
    cookies_file = os.path.join(os.path.expanduser("~"), ".cache/gdown/cookies.json")

    session = requests.session()
    session.headers.update({"User-Agent": "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_12_6)"})

    if use_cookies and os.path.exists(cookies_file):
        with open(cookies_file) as fh:
            for name, value in json.load(fh):
                session.cookies[name] = value

    if return_cookies_file:
        return session, cookies_file
    return session
88
+
89
def gdown_download(url=None, output=None):
    """Download a Google Drive file, following the virus-scan
    confirmation page when Drive interposes one.

    url: a Drive share/open/uc link (required).
    output: destination directory (defaults to the current directory);
        the filename is taken from Content-Disposition or the URL path.
    Returns the path of the downloaded file, or None when no download
    was performed.
    Raises ValueError when url is missing; confirmation-page parsing
    failures are re-raised as Exception.
    """
    file_id = None
    if url is None: raise ValueError(translations["gdown_value_error"])

    # Pull the file id out of the common share-link shapes.
    if "/file/d/" in url:
        file_id = url.split("/d/")[1].split("/")[0]
    elif "open?id=" in url:
        file_id = url.split("open?id=")[1].split("/")[0]
    elif "/download?id=" in url:
        file_id = url.split("/download?id=")[1].split("&")[0]

    if file_id:
        # rot13 obfuscates the literal host; decodes to the
        # drive.google.com /uc?id= endpoint.
        url = f"{codecs.decode('uggcf://qevir.tbbtyr.pbz/hp?vq=', 'rot13')}{file_id}"
        url_origin = url

    sess, cookies_file = _get_session(use_cookies=True, return_cookies_file=True)
    gdrive_file_id, is_gdrive_download_link = parse_url(url)

    if gdrive_file_id:
        url = f"{codecs.decode('uggcf://qevir.tbbtyr.pbz/hp?vq=', 'rot13')}{gdrive_file_id}"
        url_origin = url
        is_gdrive_download_link = True

    # Keep requesting until the response carries the file bytes
    # (Content-Disposition present) or no confirmation page remains.
    while 1:
        res = sess.get(url, stream=True, verify=True)
        # Some files answer 500 on /uc; retry once via the /open endpoint.
        if url == url_origin and res.status_code == 500:
            url = f"{codecs.decode('uggcf://qevir.tbbtyr.pbz/bcra?vq=', 'rot13')}{gdrive_file_id}"
            continue

        # Persist session cookies (minus per-file download warnings) so
        # later calls through _get_session can reuse them.
        os.makedirs(os.path.dirname(cookies_file), exist_ok=True)
        with open(cookies_file, "w") as f:
            json.dump(
                [(k, v) for k, v in sess.cookies.items() if not k.startswith("download_warning_")],
                f,
                indent=2
            )

        if ("Content-Disposition" in res.headers) or (not (gdrive_file_id and is_gdrive_download_link)): break

        try:
            # The body is Drive's HTML confirmation page; extract the
            # real download URL from it and loop.
            url = get_url_from_gdrive_confirmation(res.text)
        except Exception as e:
            raise Exception(e)

    if gdrive_file_id and is_gdrive_download_link:
        content_disposition = unquote(res.headers["Content-Disposition"])

        # Prefer the RFC 5987 filename*=UTF-8'' form, fall back to plain
        # filename=; flatten path separators so the name stays local.
        filename_from_url = (
            re.search(r"filename\*=UTF-8''(.*)", content_disposition) or re.search(r'filename=["\']?(.*?)["\']?$', content_disposition)
        ).group(1).replace(os.path.sep, "_")
    else:
        filename_from_url = os.path.basename(url)

    output = os.path.join(output or ".", filename_from_url)
    # NOTE(review): tempfile.mktemp is race-prone (the name can be taken
    # between this call and open below); mkstemp would be safer.
    tmp_file = tempfile.mktemp(suffix=tempfile.template, prefix=os.path.basename(output), dir=os.path.dirname(output))
    f = open(tmp_file, "ab")

    # NOTE(review): mktemp always returns a fresh name, so f.tell() is 0
    # here and this resume branch looks unreachable — confirm intent.
    if tmp_file is not None and f.tell() != 0:
        res = sess.get(
            url,
            headers={
                "Range": f"bytes={f.tell()}-"
            },
            stream=True,
            verify=True
        )

    try:
        with tqdm.tqdm(
            desc=os.path.basename(output),
            total=int(res.headers.get("Content-Length", 0)),
            ncols=100,
            unit="byte"
        ) as pbar:
            for chunk in res.iter_content(chunk_size=512 * 1024):
                f.write(chunk)
                pbar.update(len(chunk))

            pbar.close()
        if tmp_file: f.close()
    finally:
        # Move the finished temp file into its final place.
        os.rename(tmp_file, output)
        sess.close()

    return output

    # NOTE(review): unreachable — kept verbatim from the original.
    return None
tools/utils/huggingface.py ADDED
@@ -0,0 +1,43 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import tqdm
3
+ import requests
4
+
5
+ try:
6
+ import wget
7
+ except:
8
+ wget = None
9
+
10
def HF_download_file(url, output_path=None):
    """Download a Hugging Face hosted file.

    Blob-page URLs are rewritten to their ``/resolve/`` form and the
    ``?download=true`` suffix is stripped. When the optional ``wget``
    module is importable it performs the download; otherwise a streamed
    ``requests`` download with a tqdm progress bar is used.

    output_path: None (download next to cwd under the URL's basename),
        an existing directory (file goes inside it), or a full file path.
    Returns the path of the written file.
    Raises ValueError with the HTTP status code on a non-200 response
    (requests path only).
    """
    url = url.replace("/blob/", "/resolve/").replace("?download=true", "").strip()

    # Resolve the destination: bare filename, inside a given directory,
    # or an explicit file path.
    if output_path is None:
        output_path = os.path.basename(url)
    elif os.path.isdir(output_path):
        output_path = os.path.join(output_path, os.path.basename(url))

    if wget is not None:
        wget.download(
            url,
            out=output_path
        )
    else:
        response = requests.get(url, stream=True, timeout=300)

        if response.status_code == 200:
            progress_bar = tqdm.tqdm(
                total=int(response.headers.get("content-length", 0)),
                desc=os.path.basename(url),
                ncols=100,
                unit="byte",
                leave=False
            )

            with open(output_path, "wb") as f:
                for chunk in response.iter_content(chunk_size=10 * 1024 * 1024):
                    progress_bar.update(len(chunk))
                    f.write(chunk)

            progress_bar.close()
        else: raise ValueError(response.status_code)

    return output_path
tools/utils/mediafire.py ADDED
@@ -0,0 +1,43 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import sys
3
+ import requests
4
+
5
+ from bs4 import BeautifulSoup
6
+
7
def Mediafire_Download(url, output=None, filename=None):
    """Download a file from a Mediafire page.

    Scrapes the page's download button for the direct link, then streams
    it to *output*/*filename* (defaults: this module's directory and the
    second-to-last URL path segment), printing a one-line progress
    indicator. Returns the written file path; any failure is wrapped in
    ``RuntimeError``.
    """
    filename = filename or url.split('/')[-2]
    output = output or os.path.dirname(os.path.realpath(__file__))
    destination = os.path.join(output, filename)

    session = requests.session()
    session.headers.update({"User-Agent": "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_12_6)"})

    try:
        # The page's #downloadButton anchor holds the direct link.
        page = session.get(url).content
        direct_link = BeautifulSoup(page, "html.parser").find(id="downloadButton").get("href")

        with requests.get(direct_link, stream=True) as r:
            r.raise_for_status()

            with open(destination, "wb") as fh:
                total_length = int(r.headers.get('content-length'))
                download_progress = 0

                for chunk in r.iter_content(chunk_size=1024):
                    download_progress += len(chunk)
                    fh.write(chunk)

                    sys.stdout.write(f"\r[(unknown)]: {int(100 * download_progress / total_length)}% ({round(download_progress / 1024 / 1024, 2)}mb/{round(total_length / 1024 / 1024, 2)}mb)")
                    sys.stdout.flush()

        sys.stdout.write("\n")
        return destination
    except Exception as e:
        raise RuntimeError(e)
tools/utils/meganz.py ADDED
@@ -0,0 +1,159 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import re
3
+ import sys
4
+ import json
5
+ import tqdm
6
+ import codecs
7
+ import random
8
+ import base64
9
+ import struct
10
+ import shutil
11
+ import requests
12
+ import tempfile
13
+
14
+ from Crypto.Cipher import AES
15
+ from Crypto.Util import Counter
16
+
17
+ sys.path.append(os.getcwd())
18
+
19
+ from main.app.variables import translations
20
+
21
def makebyte(x):
    """Encode *x* to bytes via Latin-1 (one byte per codepoint)."""
    encoded, _length = codecs.latin_1_encode(x)
    return encoded
23
+
24
def a32_to_str(a):
    """Pack a sequence of unsigned 32-bit integers into big-endian bytes."""
    fmt = '>%dI' % len(a)
    return struct.pack(fmt, *a)
26
+
27
def get_chunks(size):
    """Yield (offset, length) pairs covering *size* bytes.

    Chunk length starts at 128 KiB and grows by 128 KiB per chunk up to
    a 1 MiB cap; the final chunk covers whatever remains.
    """
    offset = 0
    step = 0x20000

    while offset + step < size:
        yield offset, step
        offset += step
        # Grow the chunk, capped at 1 MiB (same rule as s < 0x100000).
        step = min(step + 0x20000, 0x100000)

    yield offset, size - offset
37
+
38
def aes_cbc_decrypt(data, key):
    """Decrypt *data* with AES-CBC under *key* using an all-zero IV."""
    zero_iv = makebyte('\0' * 16)
    return AES.new(key, AES.MODE_CBC, zero_iv).decrypt(data)
42
+
43
def decrypt_attr(attr, key):
    """Decrypt a MEGA attribute block and parse its JSON payload.

    Returns the parsed object when the plaintext carries the expected
    ``MEGA{"`` prefix, otherwise ``False``.
    """
    plain = codecs.latin_1_decode(aes_cbc_decrypt(attr, a32_to_str(key)))[0].rstrip('\0')

    if plain[:6] != 'MEGA{"':
        return False
    return json.loads(plain[4:])
47
+
48
def _api_request(data):
    """POST a command to the MEGA API and return the first result.

    data: a command dict, or a list of them (a bare dict is wrapped).
    Raises ``Exception`` carrying the numeric MEGA error code when the
    API answers with a bare integer instead of a result list.
    """
    # A random request id is sufficient here; the original also kept a
    # dead `sequence_num += 1` (never read again) which is removed.
    params = {'id': random.randint(0, 0xFFFFFFFF)}

    if not isinstance(data, list): data = [data]

    response = requests.post(
        '{0}://g.api.{1}/cs'.format('https', 'mega.co.nz'),
        params=params,
        data=json.dumps(data),
        timeout=160
    )

    json_resp = json.loads(response.text)
    # MEGA signals errors by returning a bare integer error code.
    if isinstance(json_resp, int): raise Exception(json_resp)

    return json_resp[0]
66
+
67
def base64_url_decode(data):
    """Decode MEGA's URL-safe base64 variant.

    Restores stripped '=' padding, maps '-'/'_' back to '+'/'/', and
    drops stray ',' characters before standard decoding.
    """
    padding = '=='[(2 - len(data) * 3) % 4:]
    table = str.maketrans('-_', '+/', ',')
    return base64.b64decode((data + padding).translate(table))
74
+
75
def str_to_a32(b):
    """Unpack bytes (or a Latin-1 string) into a tuple of big-endian
    unsigned 32-bit integers, zero-padding to a 4-byte boundary.

    b: ``bytes`` or ``str`` (strings are Latin-1 encoded first).
    Returns a tuple of ints, one per 32-bit word.
    """
    if isinstance(b, str): b = makebyte(b)
    if len(b) % 4: b += b'\0' * (4 - len(b) % 4)

    # Integer floor division: the original used `len(b) / 4`, relying on
    # '%d' truncating the float — works, but fragile.
    return struct.unpack('>%dI' % (len(b) // 4), b)
80
+
81
def base64_to_a32(s):
    """Decode MEGA base64 text straight into a tuple of 32-bit words."""
    raw = base64_url_decode(s)
    return str_to_a32(raw)
85
+
86
def mega_download_file(file_handle, file_key, dest_path=None):
    """Fetch and decrypt a MEGA file identified by *file_handle*.

    file_handle: the public node handle sent to the 'g' API command.
    file_key: the base64-encoded 8-word node key.
    dest_path: directory the finished file is moved into; the filename
        comes from the decrypted attribute block ('n').
    Returns the final file path; raises when the API denies access or
    the computed MAC does not match the key's MAC words.
    """
    file_key = base64_to_a32(file_key)
    file_data = _api_request({'a': 'g', 'g': 1, 'p': file_handle})

    # AES key = upper half of the node key XOR lower half (MEGA scheme).
    k = (
        file_key[0] ^ file_key[4],
        file_key[1] ^ file_key[5],
        file_key[2] ^ file_key[6],
        file_key[3] ^ file_key[7]
    )

    # Words 4-5 of the key seed the CTR nonce; MAC seed lives in 6-7.
    iv = file_key[4:6] + (0, 0)
    if 'g' not in file_data: raise Exception(translations["file_not_access"])

    file_size = file_data['s']
    attribs = decrypt_attr(base64_url_decode(file_data['at']), k)

    # Stream the ciphertext straight off the CDN URL ('g').
    input_file = requests.get(file_data['g'], stream=True).raw
    temp_output_file = tempfile.NamedTemporaryFile(mode='w+b', prefix='megapy_', delete=False)

    k_str = a32_to_str(k)
    # CTR counter: 64-bit nonce from iv words, 64-bit block counter.
    aes = AES.new(
        k_str,
        AES.MODE_CTR,
        counter=Counter.new(
            128,
            initial_value=(
                (iv[0] << 32) + iv[1]
            ) << 64
        )
    )

    mac_str = b'\0' * 16
    mac_encryptor = AES.new(k_str, AES.MODE_CBC, mac_str)
    iv_str = a32_to_str([iv[0], iv[1], iv[0], iv[1]])

    with tqdm.tqdm(total=file_size, ncols=100, unit="byte") as pbar:
        for _, chunk_size in get_chunks(file_size):
            chunk = aes.decrypt(input_file.read(chunk_size))
            temp_output_file.write(chunk)

            pbar.update(len(chunk))
            # Per-chunk CBC-MAC: a fresh CBC cipher chains the chunk's
            # blocks; only the final block's output is folded into the
            # running file MAC below.
            encryptor = AES.new(k_str, AES.MODE_CBC, iv_str)

            for i in range(0, len(chunk) - 16, 16):
                block = chunk[i:i + 16]
                encryptor.encrypt(block)

            # NOTE(review): relies on the loop variable `i` surviving the
            # for loop; for very small files the `file_size > 16` guard
            # resets to block 0 — confirm against upstream mega.py.
            i = (i + 16) if file_size > 16 else 0
            block = chunk[i:i + 16]

            if len(block) % 16: block += b'\0' * (16 - (len(block) % 16))
            mac_str = mac_encryptor.encrypt(encryptor.encrypt(block))

    file_mac = str_to_a32(mac_str)
    temp_output_file.close()

    # Condensed MAC (word0^word1, word2^word3) must equal key words 6-7.
    if (file_mac[0] ^ file_mac[1], file_mac[2] ^ file_mac[3]) != file_key[6:8]: raise ValueError(translations["mac_not_match"])

    file_path = os.path.join(dest_path, attribs['n'])
    if os.path.exists(file_path): os.remove(file_path)

    shutil.move(temp_output_file.name, file_path)
    return file_path
150
+
151
def mega_download_url(url, dest_path=None):
    """Parse a MEGA link into (handle, key) and download the file.

    Supports new-style '/file/<id>#<key>' links and legacy '#!<id>!<key>'
    links; raises when neither form is recognized.
    """
    if '/file/' in url:
        # New-style link: grab the 8-char handle, then everything after
        # it (past one separator character) is the key.
        url = url.replace(' ', '')
        handle = re.findall(r'\W\w\w\w\w\w\w\w\w\W', url)[0][1:-1]
        key_start = re.search(handle, url).end() + 1
        path = f'{handle}!{url[key_start:]}'.split('!')
    elif '!' in url:
        path = re.findall(r'/#!(.*)', url)[0].split('!')
    else:
        raise Exception(translations["missing_url"])

    return mega_download_file(path[0], path[1], dest_path)
tools/utils/noisereduce.py CHANGED
@@ -1,125 +1,58 @@
 
 
1
  import torch
2
- import tempfile
3
- import numpy as np
4
 
5
- from joblib import Parallel, delayed
6
  from torch.nn.functional import conv1d, conv2d
7
 
8
- @torch.no_grad()
9
- def amp_to_db(x, eps = torch.finfo(torch.float32).eps, top_db = 40):
10
- x_db = 20 * torch.log10(x.abs() + eps)
11
- return torch.max(x_db, (x_db.max(-1).values - top_db).unsqueeze(-1))
12
 
13
  @torch.no_grad()
14
  def temperature_sigmoid(x, x0, temp_coeff):
15
- return torch.sigmoid((x - x0) / temp_coeff)
16
 
17
  @torch.no_grad()
18
  def linspace(start, stop, num = 50, endpoint = True, **kwargs):
19
- return torch.linspace(start, stop, num, **kwargs) if endpoint else torch.linspace(start, stop, num + 1, **kwargs)[:-1]
20
-
21
- def _smoothing_filter(n_grad_freq, n_grad_time):
22
- smoothing_filter = np.outer(np.concatenate([np.linspace(0, 1, n_grad_freq + 1, endpoint=False), np.linspace(1, 0, n_grad_freq + 2)])[1:-1], np.concatenate([np.linspace(0, 1, n_grad_time + 1, endpoint=False), np.linspace(1, 0, n_grad_time + 2)])[1:-1])
23
- return smoothing_filter / np.sum(smoothing_filter)
 
 
 
 
 
 
 
 
 
 
24
 
25
- class SpectralGate:
26
- def __init__(self, y, sr, prop_decrease, chunk_size, padding, n_fft, win_length, hop_length, time_constant_s, freq_mask_smooth_hz, time_mask_smooth_ms, tmp_folder, use_tqdm, n_jobs):
27
- self.sr = sr
28
- self.flat = False
29
- y = np.array(y)
30
-
31
- if len(y.shape) == 1:
32
- self.y = np.expand_dims(y, 0)
33
- self.flat = True
34
- elif len(y.shape) > 2: raise ValueError
35
- else: self.y = y
36
-
37
- self._dtype = y.dtype
38
- self.n_channels, self.n_frames = self.y.shape
39
- self._chunk_size = chunk_size
40
- self.padding = padding
41
- self.n_jobs = n_jobs
42
- self.use_tqdm = use_tqdm
43
- self._tmp_folder = tmp_folder
44
- self._n_fft = n_fft
45
- self._win_length = self._n_fft if win_length is None else win_length
46
- self._hop_length = (self._win_length // 4) if hop_length is None else hop_length
47
- self._time_constant_s = time_constant_s
48
- self._prop_decrease = prop_decrease
49
-
50
- if (freq_mask_smooth_hz is None) & (time_mask_smooth_ms is None): self.smooth_mask = False
51
- else: self._generate_mask_smoothing_filter(freq_mask_smooth_hz, time_mask_smooth_ms)
52
-
53
- def _generate_mask_smoothing_filter(self, freq_mask_smooth_hz, time_mask_smooth_ms):
54
- if freq_mask_smooth_hz is None: n_grad_freq = 1
55
- else:
56
- n_grad_freq = int(freq_mask_smooth_hz / (self.sr / (self._n_fft / 2)))
57
- if n_grad_freq < 1: raise ValueError
58
 
59
- if time_mask_smooth_ms is None: n_grad_time = 1
60
- else:
61
- n_grad_time = int(time_mask_smooth_ms / ((self._hop_length / self.sr) * 1000))
62
- if n_grad_time < 1: raise ValueError
63
 
64
- if (n_grad_time == 1) & (n_grad_freq == 1): self.smooth_mask = False
65
- else:
66
- self.smooth_mask = True
67
- self._smoothing_filter = _smoothing_filter(n_grad_freq, n_grad_time)
68
-
69
- def _read_chunk(self, i1, i2):
70
- i1b = 0 if i1 < 0 else i1
71
- i2b = self.n_frames if i2 > self.n_frames else i2
72
- chunk = np.zeros((self.n_channels, i2 - i1))
73
- chunk[:, i1b - i1: i2b - i1] = self.y[:, i1b:i2b]
74
- return chunk
75
-
76
- def filter_chunk(self, start_frame, end_frame):
77
- i1 = start_frame - self.padding
78
- return self._do_filter(self._read_chunk(i1, (end_frame + self.padding)))[:, start_frame - i1: end_frame - i1]
79
-
80
- def _get_filtered_chunk(self, ind):
81
- start0 = ind * self._chunk_size
82
- end0 = (ind + 1) * self._chunk_size
83
- return self.filter_chunk(start_frame=start0, end_frame=end0)
84
-
85
- def _do_filter(self, chunk):
86
- pass
87
-
88
- def _iterate_chunk(self, filtered_chunk, pos, end0, start0, ich):
89
- filtered_chunk[:, pos: pos + end0 - start0] = self._get_filtered_chunk(ich)[:, start0:end0]
90
- pos += end0 - start0
91
-
92
- def get_traces(self, start_frame=None, end_frame=None):
93
- if start_frame is None: start_frame = 0
94
- if end_frame is None: end_frame = self.n_frames
95
-
96
- if self._chunk_size is not None:
97
- if end_frame - start_frame > self._chunk_size:
98
- ich1 = int(start_frame / self._chunk_size)
99
- ich2 = int((end_frame - 1) / self._chunk_size)
100
-
101
- with tempfile.NamedTemporaryFile(prefix=self._tmp_folder) as fp:
102
- filtered_chunk = np.memmap(fp, dtype=self._dtype, shape=(self.n_channels, int(end_frame - start_frame)), mode="w+")
103
- pos_list, start_list, end_list = [], [], []
104
- pos = 0
105
-
106
- for ich in range(ich1, ich2 + 1):
107
- start0 = (start_frame - ich * self._chunk_size) if ich == ich1 else 0
108
- end0 = end_frame - ich * self._chunk_size if ich == ich2 else self._chunk_size
109
- pos_list.append(pos)
110
- start_list.append(start0)
111
- end_list.append(end0)
112
- pos += end0 - start0
113
-
114
- Parallel(n_jobs=self.n_jobs)(delayed(self._iterate_chunk)(filtered_chunk, pos, end0, start0, ich) for pos, start0, end0, ich in zip(pos_list, start_list, end_list, range(ich1, ich2 + 1)))
115
- return filtered_chunk.astype(self._dtype).flatten() if self.flat else filtered_chunk.astype(self._dtype)
116
-
117
- filtered_chunk = self.filter_chunk(start_frame=0, end_frame=end_frame)
118
- return filtered_chunk.astype(self._dtype).flatten() if self.flat else filtered_chunk.astype(self._dtype)
119
-
120
- class TG(torch.nn.Module):
121
  @torch.no_grad()
122
- def __init__(self, sr, nonstationary = False, n_std_thresh_stationary = 1.5, n_thresh_nonstationary = 1.3, temp_coeff_nonstationary = 0.1, n_movemean_nonstationary = 20, prop_decrease = 1.0, n_fft = 1024, win_length = None, hop_length = None, freq_mask_smooth_hz = 500, time_mask_smooth_ms = 50):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
123
  super().__init__()
124
  self.sr = sr
125
  self.nonstationary = nonstationary
@@ -146,51 +79,100 @@ class TG(torch.nn.Module):
146
  if n_grad_time < 1: raise ValueError
147
  if n_grad_time == 1 and n_grad_freq == 1: return None
148
 
149
- smoothing_filter = torch.outer(torch.cat([linspace(0, 1, n_grad_freq + 1, endpoint=False), linspace(1, 0, n_grad_freq + 2)])[1:-1], torch.cat([linspace(0, 1, n_grad_time + 1, endpoint=False), linspace(1, 0, n_grad_time + 2)])[1:-1]).unsqueeze(0).unsqueeze(0)
 
 
 
 
 
 
 
 
 
 
150
  return smoothing_filter / smoothing_filter.sum()
151
 
152
  @torch.no_grad()
153
- def _stationary_mask(self, X_db, xn = None):
154
- XN_db = amp_to_db(torch.stft(xn, n_fft=self.n_fft, hop_length=self.hop_length, win_length=self.win_length, return_complex=True, pad_mode="constant", center=True, window=torch.hann_window(self.win_length).to(xn.device))).to(dtype=X_db.dtype) if xn is not None else X_db
155
- std_freq_noise, mean_freq_noise = torch.std_mean(XN_db, dim=-1)
156
- return torch.gt(X_db, (mean_freq_noise + std_freq_noise * self.n_std_thresh_stationary).unsqueeze(2))
157
 
158
  @torch.no_grad()
159
  def _nonstationary_mask(self, X_abs):
160
- X_smoothed = (conv1d(X_abs.reshape(-1, 1, X_abs.shape[-1]), torch.ones(self.n_movemean_nonstationary, dtype=X_abs.dtype, device=X_abs.device).view(1, 1, -1), padding="same").view(X_abs.shape) / self.n_movemean_nonstationary)
161
- return temperature_sigmoid(((X_abs - X_smoothed) / X_smoothed), self.n_thresh_nonstationary, self.temp_coeff_nonstationary)
162
-
163
- def forward(self, x, xn = None):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
164
  assert x.ndim == 2
165
  if x.shape[-1] < self.win_length * 2: raise Exception
166
- assert xn is None or xn.ndim == 1 or xn.ndim == 2
167
- if xn is not None and xn.shape[-1] < self.win_length * 2: raise Exception
168
-
169
- X = torch.stft(x, n_fft=self.n_fft, hop_length=self.hop_length, win_length=self.win_length, return_complex=True, pad_mode="constant", center=True, window=torch.hann_window(self.win_length).to(x.device))
170
- sig_mask = self._nonstationary_mask(X.abs()) if self.nonstationary else self._stationary_mask(amp_to_db(X), xn)
171
 
172
- sig_mask = self.prop_decrease * (sig_mask * 1.0 - 1.0) + 1.0
173
- if self.smoothing_filter is not None: sig_mask = conv2d(sig_mask.unsqueeze(1), self.smoothing_filter.to(sig_mask.dtype), padding="same")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
174
 
175
  Y = X * sig_mask.squeeze(1)
176
- return torch.istft(Y, n_fft=self.n_fft, hop_length=self.hop_length, win_length=self.win_length, center=True, window=torch.hann_window(self.win_length).to(Y.device)).to(dtype=x.dtype)
177
-
178
- class StreamedTorchGate(SpectralGate):
179
- def __init__(self, y, sr, stationary=False, y_noise=None, prop_decrease=1.0, time_constant_s=2.0, freq_mask_smooth_hz=500, time_mask_smooth_ms=50, thresh_n_mult_nonstationary=2, sigmoid_slope_nonstationary=10, n_std_thresh_stationary=1.5, tmp_folder=None, chunk_size=600000, padding=30000, n_fft=1024, win_length=None, hop_length=None, clip_noise_stationary=True, use_tqdm=False, n_jobs=1, device="cpu"):
180
- super().__init__(y=y, sr=sr, chunk_size=chunk_size, padding=padding, n_fft=n_fft, win_length=win_length, hop_length=hop_length, time_constant_s=time_constant_s, freq_mask_smooth_hz=freq_mask_smooth_hz, time_mask_smooth_ms=time_mask_smooth_ms, tmp_folder=tmp_folder, prop_decrease=prop_decrease, use_tqdm=use_tqdm, n_jobs=n_jobs)
181
- self.device = torch.device(device if torch.cuda.is_available() else 'cpu')
182
-
183
- if y_noise is not None:
184
- if y_noise.shape[-1] > y.shape[-1] and clip_noise_stationary: y_noise = y_noise[: y.shape[-1]]
185
- y_noise = torch.from_numpy(y_noise).to(device)
186
- if len(y_noise.shape) == 1: y_noise = y_noise.unsqueeze(0)
187
-
188
- self.y_noise = y_noise
189
- self.tg = TG(sr=sr, nonstationary=not stationary, n_std_thresh_stationary=n_std_thresh_stationary, n_thresh_nonstationary=thresh_n_mult_nonstationary, temp_coeff_nonstationary=1 / sigmoid_slope_nonstationary, n_movemean_nonstationary=int(time_constant_s / self._hop_length * sr), prop_decrease=prop_decrease, n_fft=self._n_fft, win_length=self._win_length, hop_length=self._hop_length, freq_mask_smooth_hz=freq_mask_smooth_hz, time_mask_smooth_ms=time_mask_smooth_ms).to(device)
190
-
191
- def _do_filter(self, chunk):
192
- if type(chunk) is np.ndarray: chunk = torch.from_numpy(chunk).to(self.device)
193
- return self.tg(x=chunk, xn=self.y_noise).cpu().detach().numpy()
194
 
195
- def reduce_noise(y, sr, stationary=False, y_noise=None, prop_decrease=1.0, time_constant_s=2.0, freq_mask_smooth_hz=500, time_mask_smooth_ms=50, thresh_n_mult_nonstationary=2, sigmoid_slope_nonstationary=10, tmp_folder=None, chunk_size=600000, padding=30000, n_fft=1024, win_length=None, hop_length=None, clip_noise_stationary=True, use_tqdm=False, device="cpu"):
196
- return StreamedTorchGate(y=y, sr=sr, stationary=stationary, y_noise=y_noise, prop_decrease=prop_decrease, time_constant_s=time_constant_s, freq_mask_smooth_hz=freq_mask_smooth_hz, time_mask_smooth_ms=time_mask_smooth_ms, thresh_n_mult_nonstationary=thresh_n_mult_nonstationary, sigmoid_slope_nonstationary=sigmoid_slope_nonstationary, tmp_folder=tmp_folder, chunk_size=chunk_size, padding=padding, n_fft=n_fft, win_length=win_length, hop_length=hop_length, clip_noise_stationary=clip_noise_stationary, use_tqdm=use_tqdm, n_jobs=1, device=device).get_traces()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import sys
3
  import torch
 
 
4
 
 
5
  from torch.nn.functional import conv1d, conv2d
6
 
7
+ sys.path.append(os.getcwd())
 
 
 
8
 
9
  @torch.no_grad()
10
  def temperature_sigmoid(x, x0, temp_coeff):
11
+ return ((x - x0) / temp_coeff).sigmoid()
12
 
13
  @torch.no_grad()
14
  def linspace(start, stop, num = 50, endpoint = True, **kwargs):
15
+ return (
16
+ torch.linspace(
17
+ start,
18
+ stop,
19
+ num,
20
+ **kwargs
21
+ )
22
+ ) if endpoint else (
23
+ torch.linspace(
24
+ start,
25
+ stop,
26
+ num + 1,
27
+ **kwargs
28
+ )[:-1]
29
+ )
30
 
31
+ @torch.no_grad()
32
+ def amp_to_db(x, eps=torch.finfo(torch.float32).eps, top_db=40):
33
+ x_db = 20 * (x + eps).log10()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
34
 
35
+ return x_db.max(
36
+ (x_db.max(-1).values - top_db).unsqueeze(-1)
37
+ )
 
38
 
39
+ class TorchGate(torch.nn.Module):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
40
  @torch.no_grad()
41
+ def __init__(
42
+ self,
43
+ sr,
44
+ nonstationary = False,
45
+ n_std_thresh_stationary = 1.5,
46
+ n_thresh_nonstationary = 1.3,
47
+ temp_coeff_nonstationary = 0.1,
48
+ n_movemean_nonstationary = 20,
49
+ prop_decrease = 1.0,
50
+ n_fft = 1024,
51
+ win_length = None,
52
+ hop_length = None,
53
+ freq_mask_smooth_hz = 500,
54
+ time_mask_smooth_ms = 50
55
+ ):
56
  super().__init__()
57
  self.sr = sr
58
  self.nonstationary = nonstationary
 
79
  if n_grad_time < 1: raise ValueError
80
  if n_grad_time == 1 and n_grad_freq == 1: return None
81
 
82
+ smoothing_filter = torch.outer(
83
+ torch.cat([
84
+ linspace(0, 1, n_grad_freq + 1, endpoint=False),
85
+ linspace(1, 0, n_grad_freq + 2)
86
+ ])[1:-1],
87
+ torch.cat([
88
+ linspace(0, 1, n_grad_time + 1, endpoint=False),
89
+ linspace(1, 0, n_grad_time + 2)
90
+ ])[1:-1]
91
+ ).unsqueeze(0).unsqueeze(0)
92
+
93
  return smoothing_filter / smoothing_filter.sum()
94
 
95
  @torch.no_grad()
96
+ def _stationary_mask(self, X_db):
97
+ std_freq_noise, mean_freq_noise = torch.std_mean(X_db, dim=-1)
98
+ return X_db > (mean_freq_noise + std_freq_noise * self.n_std_thresh_stationary).unsqueeze(2)
 
99
 
100
  @torch.no_grad()
101
  def _nonstationary_mask(self, X_abs):
102
+ X_smoothed = (
103
+ conv1d(
104
+ X_abs.reshape(-1, 1, X_abs.shape[-1]),
105
+ torch.ones(
106
+ self.n_movemean_nonstationary,
107
+ dtype=X_abs.dtype,
108
+ device=X_abs.device
109
+ ).view(1, 1, -1),
110
+ padding="same"
111
+ ).view(X_abs.shape) / self.n_movemean_nonstationary
112
+ )
113
+
114
+ return temperature_sigmoid(
115
+ ((X_abs - X_smoothed) / X_smoothed),
116
+ self.n_thresh_nonstationary,
117
+ self.temp_coeff_nonstationary
118
+ )
119
+
120
+ def forward(self, x):
121
  assert x.ndim == 2
122
  if x.shape[-1] < self.win_length * 2: raise Exception
 
 
 
 
 
123
 
124
+ if str(x.device).startswith(("ocl", "privateuseone")):
125
+ if not hasattr(self, "stft"):
126
+ from main.library.backends.utils import STFT
127
+
128
+ self.stft = STFT(
129
+ filter_length=self.n_fft,
130
+ hop_length=self.hop_length,
131
+ win_length=self.win_length,
132
+ pad_mode="constant"
133
+ ).to(x.device)
134
+
135
+ X, phase = self.stft.transform(
136
+ x,
137
+ eps=1e-9,
138
+ return_phase=True
139
+ )
140
+ else:
141
+ X = torch.stft(
142
+ x,
143
+ n_fft=self.n_fft,
144
+ hop_length=self.hop_length,
145
+ win_length=self.win_length,
146
+ return_complex=True,
147
+ pad_mode="constant",
148
+ center=True,
149
+ window=torch.hann_window(self.win_length).to(x.device)
150
+ )
151
+
152
+ sig_mask = self._nonstationary_mask(X.abs()) if self.nonstationary else self._stationary_mask(amp_to_db(X.abs()))
153
+ sig_mask = self.prop_decrease * (sig_mask.float() * 1.0 - 1.0) + 1.0
154
+
155
+ if self.smoothing_filter is not None:
156
+ sig_mask = conv2d(
157
+ sig_mask.unsqueeze(1),
158
+ self.smoothing_filter.to(sig_mask.dtype),
159
+ padding="same"
160
+ )
161
 
162
  Y = X * sig_mask.squeeze(1)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
163
 
164
+ return (
165
+ self.stft.inverse(
166
+ Y,
167
+ phase
168
+ )
169
+ ) if hasattr(self, "stft") else (
170
+ torch.istft(
171
+ Y,
172
+ n_fft=self.n_fft,
173
+ hop_length=self.hop_length,
174
+ win_length=self.win_length,
175
+ center=True,
176
+ window=torch.hann_window(self.win_length).to(Y.device)
177
+ ).to(dtype=x.dtype)
178
+ )
tools/utils/pixeldrain.py ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import requests
3
+
4
def pixeldrain(url, output_dir):
    """Download a pixeldrain.com file into *output_dir*.

    Returns the written file path, or ``None`` on a non-200 response.
    Any other failure is re-raised wrapped in ``RuntimeError``.
    """
    try:
        file_id = url.split('pixeldrain.com/u/')[1]
        response = requests.get(f"https://pixeldrain.com/api/file/{file_id}")

        if response.status_code != 200:
            return None

        # The filename comes from the Content-Disposition header.
        disposition = response.headers.get("Content-Disposition")
        filename = disposition.split("filename=")[-1].strip('";')
        file_path = os.path.join(output_dir, filename)

        with open(file_path, "wb") as fh:
            fh.write(response.content)

        return file_path
    except Exception as e:
        raise RuntimeError(e)