ASR / src /data /nptel_loader.py
MihirRPatil's picture
deploy: CDAC ASR backend with pitch/stress fix and LLM feedback
88a679b
Raw
History Blame Contribute Delete
8.67 kB
import os
import re
import requests
import tarfile
import io
import time
import soundfile as sf
from pathlib import Path
from tqdm import tqdm
from requests.adapters import HTTPAdapter
from requests.packages.urllib3.util.retry import Retry
class NPTELChunkedLoader:
"""
A loader for the NPTEL2020 dataset. The dataset is a SINGLE large tar.gz
split into 20 parts (partaa-partat). These parts must be concatenated
into one stream before decompression.
Since the tar stores ALL .wav files before ALL .txt files, this loader
uses a TWO-PASS streaming approach:
Pass 1: Stream the entire dataset, collecting only .txt transcriptions (~10MB RAM).
Pass 2: Stream again, matching each .wav with pre-loaded text and yielding immediately.
This doubles download time but uses almost zero RAM and zero disk.
"""
def __init__(self, script_path):
self.script_path = script_path
self.urls = self._parse_script(script_path)
print(f"✓ Found {len(self.urls)} dataset parts in {script_path}")
# Setup a robust session
self.session = requests.Session()
retry_strategy = Retry(
total=5,
backoff_factor=2,
status_forcelist=[429, 500, 502, 503, 504],
)
self.session.mount("https://", HTTPAdapter(max_retries=retry_strategy))
def _parse_script(self, path):
urls = []
if not os.path.exists(path):
print(f"Warning: Script {path} not found.")
return urls
with open(path, "r") as f:
for line in f:
match = re.search(r"https://zenodo.org/record/\d+/files/[^\s]+", line)
if match:
urls.append(match.group(0))
return urls
def _make_concatenated_stream(self, desc="NPTEL Stream"):
"""Creates a streaming reader that concatenates all 20 parts over HTTP."""
class RobustConcatenatedStream(io.RawIOBase):
def __init__(self, loader, urls, desc):
self.loader = loader
self.urls = urls
self.current_url_idx = 0
self.bytes_read_in_current_url = 0
self.current_resp = None
self.current_iter = None
self.pbar = tqdm(unit='B', unit_scale=True, desc=desc, dynamic_ncols=True)
def _get_next_iter(self):
url = self.urls[self.current_url_idx]
max_retries = 10
for attempt in range(max_retries):
try:
headers = {}
if self.bytes_read_in_current_url > 0:
headers['Range'] = f"bytes={self.bytes_read_in_current_url}-"
resp = self.loader.session.get(url, stream=True, timeout=(15, 300), headers=headers)
if resp.status_code == 200 and self.bytes_read_in_current_url > 0:
print(f"\n! Server doesn't support Resuming. Skipping {self.bytes_read_in_current_url} bytes...")
it = resp.iter_content(chunk_size=1024*1024)
skipped = 0
while skipped < self.bytes_read_in_current_url:
chunk = next(it)
skipped += len(chunk)
self.current_iter = it
else:
resp.raise_for_status()
self.current_iter = resp.iter_content(chunk_size=128*1024)
self.current_resp = resp
if self.current_url_idx == 0 and self.bytes_read_in_current_url == 0:
content_length = resp.headers.get('content-length')
if content_length:
self.pbar.total = int(content_length)
return
except Exception as e:
print(f"\n! Connection lost ({e}). Retry {attempt+1}/{max_retries} in 5s...")
time.sleep(5)
raise ConnectionError("Maximum retries exceeded for dataset stream.")
def readinto(self, b):
if self.current_iter is None:
if self.current_url_idx >= len(self.urls):
return 0
self._get_next_iter()
try:
chunk = next(self.current_iter)
n = len(chunk)
b[:n] = chunk
self.bytes_read_in_current_url += n
self.pbar.update(n)
return n
except StopIteration:
self.current_url_idx += 1
self.bytes_read_in_current_url = 0
self.current_iter = None
return self.readinto(b)
except Exception as e:
print(f"\n! Stream interrupted mid-chunk: {e}. Attempting to resume...")
self.current_iter = None
return self.readinto(b)
def close(self):
if hasattr(self, 'pbar'):
self.pbar.close()
super().close()
return RobustConcatenatedStream(self, self.urls, desc)
def stream_data(self):
"""Two-pass streaming: collect text first, then match audio."""
# ============ PASS 1: Collect all transcriptions ============
print("\n" + "="*60)
print("📝 PASS 1/2: Streaming dataset to collect transcriptions...")
print(" (Only reading .txt files, skipping audio to save RAM)")
print("="*60)
text_map = {}
stream1 = self._make_concatenated_stream(desc="Pass 1 (Text)")
try:
with tarfile.open(fileobj=stream1, mode='r|gz') as tar:
for member in tar:
if not member.isfile():
continue
if member.name.endswith(".txt"):
sid = os.path.basename(member.name).replace(".txt", "")
f = tar.extractfile(member)
if f:
text_map[sid] = f.read().decode("utf-8").strip()
if len(text_map) % 10000 == 0:
print(f" Collected {len(text_map)} transcriptions...")
except Exception as e:
print(f"\n⚠️ Pass 1 stream error: {e}")
finally:
stream1.close()
print(f"\n✓ Pass 1 complete: Collected {len(text_map)} transcriptions ({len(text_map)*100//1024//1024} MB est.)")
if not text_map:
print("❌ No transcriptions found! Cannot proceed.")
return
# ============ PASS 2: Stream audio and yield matches ============
print("\n" + "="*60)
print("🔊 PASS 2/2: Streaming dataset to read audio & yield samples...")
print("="*60)
matched = 0
stream2 = self._make_concatenated_stream(desc="Pass 2 (Audio)")
try:
with tarfile.open(fileobj=stream2, mode='r|gz') as tar:
for member in tar:
if not member.isfile():
continue
if member.name.endswith(".wav"):
sid = os.path.basename(member.name).replace(".wav", "")
if sid in text_map:
f = tar.extractfile(member)
if f:
audio_data, sr = sf.read(io.BytesIO(f.read()))
matched += 1
if matched % 1000 == 0:
print(f" ✓ Yielded {matched} samples...")
yield {
"id": sid,
"audio": {"array": audio_data, "sampling_rate": sr},
"transcription": text_map[sid]
}
except Exception as e:
print(f"\n⚠️ Pass 2 stream error: {e}")
finally:
stream2.close()
print(f"\n✅ Complete! Yielded {matched} total matched samples.")
def get_iterable_dataset(self):
from datasets import IterableDataset
return IterableDataset.from_generator(self.stream_data)