| # from concurrent.futures import ProcessPoolExecutor, as_completed | |
| # import time | |
| # from datetime import timedelta | |
| # import pandas as pd | |
| # import torch | |
| # import warnings | |
| # import logging | |
| # import os | |
| # import traceback | |
| # # --- Load and filter dataframe --- | |
| # df = pd.read_csv("/home/ubuntu/ttsar/ASR_DATA/train_large.csv") | |
| # print('before filtering: ') | |
| # print(df.shape) | |
| # df = df[~df['filename'].str.contains("Sakura, Moyu")] | |
| # print('after filtering: ') | |
| # print(df.shape) | |
| # total_samples = len(df) | |
| # # --- PyTorch settings --- | |
| # torch.set_float32_matmul_precision('high') | |
| # torch.backends.cuda.matmul.allow_tf32 = True | |
| # torch.backends.cudnn.allow_tf32 = True | |
| # def process_batch(batch_data): | |
| # """Process a batch of audio files""" | |
| # batch_id, start_idx, audio_files, config_path, checkpoint_path = batch_data | |
| # model = None # Initialize model to None for the finally block | |
| # try: | |
| # # Import and configure libraries within the worker process | |
| # import torch | |
| # import nemo.collections.asr as nemo_asr | |
| # from omegaconf import OmegaConf, open_dict | |
| # import warnings | |
| # import logging | |
| # # Suppress logs within the worker process to keep the main output clean | |
| # logging.getLogger('nemo_logger').setLevel(logging.ERROR) | |
| # logging.disable(logging.CRITICAL) | |
| # warnings.filterwarnings('ignore') | |
| # # Load model for this worker | |
| # config = OmegaConf.load(config_path) | |
| # with open_dict(config.cfg): | |
| # for ds in ['train_ds', 'validation_ds', 'test_ds']: | |
| # if ds in config.cfg: | |
| # config.cfg[ds].defer_setup = True | |
| # model = nemo_asr.models.EncDecMultiTaskModel(cfg=config.cfg) | |
| # checkpoint = torch.load(checkpoint_path, map_location='cuda', weights_only=False) | |
| # model.load_state_dict(checkpoint['state_dict'], strict=False) | |
| # model = model.eval().cuda() | |
| # decode_cfg = model.cfg.decoding | |
| # decode_cfg.beam.beam_size = 4 | |
| # model.change_decoding_strategy(decode_cfg) | |
| # # Transcribe | |
| # start = time.time() | |
| # hypotheses = model.transcribe( | |
| # audio=audio_files, | |
| # batch_size=64, | |
| # source_lang='ja', | |
| # target_lang='ja', | |
| # task='asr', | |
| # pnc='no', | |
| # verbose=False, | |
| # num_workers=0, | |
| # channel_selector=0 | |
| # ) | |
| # results = [hyp.text for hyp in hypotheses] | |
| # return batch_id, start_idx, results, len(audio_files), time.time() - start | |
| # finally: | |
| # # NEW: Ensure GPU memory is cleared in the worker process | |
| # if model is not None: | |
| # del model | |
| # import torch | |
| # torch.cuda.empty_cache() | |
| # # --- Parameters --- | |
| # chunk_size = 512 * 4 | |
| # n_workers = 4 | |
| # checkpoint_interval = 250_000 | |
| # config_path = "/home/ubuntu/NeMo_Canary/canary_results/Higurashi_ASR_v.02/version_4/hparams.yaml" | |
| # checkpoint_path = "/home/ubuntu/NeMo_Canary/canary_results/Higurashi_ASR_v.02_plus/checkpoints/Higurashi_ASR_v.02_plus--step=174650.0000-epoch=8-last.ckpt" | |
| # # --- Prepare data chunks --- | |
| # audio_files = df['filename'].tolist() | |
| # chunks = [] | |
| # for i in range(0, total_samples, chunk_size): | |
| # end_idx = min(i + chunk_size, total_samples) | |
| # chunk_files = audio_files[i:end_idx] | |
| # chunks.append({ | |
| # 'batch_id': len(chunks), | |
| # 'start_idx': i, | |
| # 'files': chunk_files, | |
| # 'config_path': config_path, | |
| # 'checkpoint_path': checkpoint_path | |
| # }) | |
| # print(f"Processing {total_samples:,} samples") | |
| # print(f"Chunks: {len(chunks)} × ~{chunk_size} samples") | |
| # print(f"Workers: {n_workers}") | |
| # print(f"Checkpoint interval: every {checkpoint_interval:,} samples") | |
| # print("-" * 50) | |
| # # --- Initialize tracking variables --- | |
| # all_results = {} | |
| # failed_chunks = [] | |
| # start_time = time.time() | |
| # samples_done = 0 | |
| # last_checkpoint = 0 | |
| # interrupted = False | |
| # # Initialize 'text' column with a placeholder | |
| # df['text'] = pd.NA | |
| # # --- Main Processing Loop with Graceful Shutdown --- | |
| # try: | |
| # with ProcessPoolExecutor(max_workers=n_workers) as executor: | |
| # future_to_chunk = { | |
| # executor.submit(process_batch, | |
| # (chunk['batch_id'], chunk['start_idx'], chunk['files'], chunk['config_path'], chunk['checkpoint_path'])): chunk | |
| # for chunk in chunks | |
| # } | |
| # for future in as_completed(future_to_chunk): | |
| # original_chunk = future_to_chunk[future] | |
| # batch_id = original_chunk['batch_id'] | |
| # try: | |
| # _batch_id, start_idx, results, count, batch_time = future.result() | |
| # all_results[start_idx] = results | |
| # samples_done += count | |
| # end_idx = start_idx + len(results) | |
| # if len(df.iloc[start_idx:end_idx]) == len(results): | |
| # df.loc[start_idx:end_idx-1, 'text'] = results | |
| # else: | |
| # raise ValueError(f"Length mismatch: DataFrame slice vs results") | |
| # elapsed = time.time() - start_time | |
| # speed = samples_done / elapsed if elapsed > 0 else 0 | |
| # remaining = total_samples - samples_done | |
| # eta = remaining / speed if speed > 0 else 0 | |
| # print(f"✓ Batch {batch_id}/{len(chunks)-1} done ({count} samples in {batch_time:.1f}s) | " | |
| # f"Total: {samples_done:,}/{total_samples:,} ({100*samples_done/total_samples:.1f}%) | " | |
| # f"Speed: {speed:.1f} samples/s | " | |
| # f"ETA: {timedelta(seconds=int(eta))}") | |
| # if samples_done - last_checkpoint >= checkpoint_interval or samples_done == total_samples: | |
| # checkpoint_file = f"/home/ubuntu/ttsar/ASR_DATA/transcribed_checkpoint_{samples_done}.csv" | |
| # df.to_csv(checkpoint_file, index=False) | |
| # print(f" ✓ Checkpoint saved: {checkpoint_file}") | |
| # last_checkpoint = samples_done | |
| # except Exception: | |
| # failed_chunks.append(original_chunk) | |
| # print("-" * 20 + " ERROR " + "-" * 20) | |
| # print(f"✗ Batch {batch_id} FAILED. Start index: {original_chunk['start_idx']}. Files: {len(original_chunk['files'])}") | |
| # traceback.print_exc() | |
| # print("-" * 47) | |
| # except KeyboardInterrupt: | |
| # interrupted = True | |
| # print("\n\n" + "="*50) | |
| # print("! KEYBOARD INTERRUPT DETECTED !") | |
| # print("Stopping workers and saving all completed progress...") | |
| # print("The script will exit shortly.") | |
| # print("="*50 + "\n") | |
| # # The `with ProcessPoolExecutor` context manager will automatically | |
| # # handle shutting down the worker processes when we exit this block. | |
| # # --- Finalization and Reporting (this block now runs on completion OR interruption) --- | |
| # total_time = time.time() - start_time | |
| # print("-" * 50) | |
| # if interrupted: | |
| # print(f"PROCESS INTERRUPTED") | |
| # else: | |
| # print(f"TRANSCRIPTION COMPLETE!") | |
| # print(f"Total time elapsed: {timedelta(seconds=int(total_time))}") | |
| # if total_time > 0 and samples_done > 0: | |
| # print(f"Average speed (on completed work): {samples_done/total_time:.1f} samples/second") | |
| # # Save final result | |
| # final_output = "/home/ubuntu/ttsar/ASR_DATA/transcribed_manifest_final.csv" | |
| # df.to_csv(final_output, index=False) | |
| # print(f"Final progress saved to: {final_output}") | |
| # print("-" * 50) | |
| # # --- Summary and Verification --- | |
| # successful_transcriptions = df['text'].notna().sum() | |
| # print("Final Run Summary:") | |
| # print(f" - Successfully transcribed: {successful_transcriptions:,} samples") | |
| # print(f" - Failed batches: {len(failed_chunks)}") | |
| # print(f" - Total samples in failed batches: {sum(len(c['files']) for c in failed_chunks):,}") | |
| # if failed_chunks: | |
| # failed_files_path = "/home/ubuntu/ttsar/ASR_DATA/failed_transcription_files.txt" | |
| # with open(failed_files_path, 'w') as f: | |
| # for chunk in failed_chunks: | |
| # for file_path in chunk['files']: | |
| # f.write(f"{file_path}\n") | |
| # print(f"\nList of files from failed batches saved to: {failed_files_path}") | |
| # print("-" * 50) | |
# NOTE: Everything above is the original full-dataset transcription script,
# kept commented out for reference. The active script below RESUMES from a
# checkpoint CSV and transcribes only the rows whose 'text' is still missing.
| from concurrent.futures import ProcessPoolExecutor, as_completed | |
| import time | |
| from datetime import timedelta | |
| import pandas as pd | |
| import torch | |
| import warnings | |
| import logging | |
| import os | |
| import traceback | |
# --- LOAD CHECKPOINT ---
# Resume source: a CSV that may already contain a partially-filled 'text' column.
checkpoint_file = "/home/ubuntu/ttsar/csv_kanad/sing/cg_shani_sing.csv"
print(f"Loading checkpoint from: {checkpoint_file}")
df = pd.read_csv(checkpoint_file)
print(f"Checkpoint loaded. Shape: {df.shape}")
# Ensure the 'text' column exists so the resume logic below can rely on it.
if 'text' not in df.columns:
    df['text'] = pd.NA
# --- FIND ALL MISSING TRANSCRIPTIONS ---
# Rows with NaN/NA in 'text' are the only ones we still need to transcribe.
missing_mask = df['text'].isna()
missing_indices = df[missing_mask].index.tolist()
already_done = (~missing_mask).sum()
print(f"Already transcribed: {already_done:,} samples")
print(f"Missing transcriptions: {len(missing_indices):,} samples")
print("-" * 50)
if not missing_indices:
    print("All samples already transcribed!")
    # FIX: raise SystemExit instead of calling exit(). exit() is injected by
    # the `site` module and is not guaranteed to exist (python -S, frozen
    # apps, some embedded interpreters); SystemExit always works.
    raise SystemExit(0)
# --- PyTorch settings ---
# TF32 matmul/cudnn paths trade a little precision for throughput on Ampere+.
torch.set_float32_matmul_precision('high')
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
def process_batch(batch_data):
    """Transcribe one batch of audio files inside a worker process.

    Each worker loads its own copy of the NeMo Canary model, runs greedy
    decoding over the batch, and tears the model down in ``finally`` so GPU
    memory is released even on failure.

    Args:
        batch_data: Tuple of ``(batch_id, indices, audio_files, config_path,
            checkpoint_path)`` where ``indices`` are the DataFrame row labels
            paired 1:1 with ``audio_files``.

    Returns:
        Tuple of ``(batch_id, [(index, text_or_None), ...], success_count,
        elapsed_seconds)``. A ``None`` text marks a file that failed.
    """
    batch_id, indices, audio_files, config_path, checkpoint_path = batch_data
    model = None  # sentinel so `finally` knows whether a model was created
    try:
        # Import and configure libraries within the worker process
        import torch
        import nemo.collections.asr as nemo_asr
        from omegaconf import OmegaConf, open_dict
        import warnings
        import logging
        # Suppress logs within the worker process
        logging.getLogger('nemo_logger').setLevel(logging.ERROR)
        logging.disable(logging.CRITICAL)
        warnings.filterwarnings('ignore')
        # Load model for this worker
        config = OmegaConf.load(config_path)
        with open_dict(config.cfg):
            # Workers only run inference: defer dataset setup so the model
            # constructor does not try to build train/val/test dataloaders.
            for ds in ['train_ds', 'validation_ds', 'test_ds']:
                if ds in config.cfg:
                    config.cfg[ds].defer_setup = True
        model = nemo_asr.models.EncDecMultiTaskModel(cfg=config.cfg)
        checkpoint = torch.load(checkpoint_path, map_location='cuda', weights_only=False)
        # NOTE(review): strict=False silently ignores missing/unexpected
        # state-dict keys — confirm the checkpoint really matches this config.
        model.load_state_dict(checkpoint['state_dict'], strict=False)
        # bfloat16 inference halves memory; assumes the GPU supports bf16.
        model = model.eval().cuda().bfloat16()
        decode_cfg = model.cfg.decoding
        decode_cfg.beam.beam_size = 1  # beam 1 == greedy decoding, for speed
        model.change_decoding_strategy(decode_cfg)
        # Transcribe
        start = time.time()
        try:
            hypotheses = model.transcribe(
                audio=audio_files,
                batch_size=64,
                source_lang='ja',
                target_lang='ja',
                task='asr',
                pnc='no',  # no punctuation/capitalization
                verbose=False,
                num_workers=0,
                channel_selector=0  # first channel of multi-channel audio
            )
            results = [hyp.text for hyp in hypotheses]
        except Exception as e:
            print(f"Transcription error in batch {batch_id}: {str(e)}")
            # Return empty results list on transcription failure
            results = []
        # Pad results with None if we got fewer results than expected, so the
        # zip below stays aligned with `indices`.
        while len(results) < len(audio_files):
            results.append(None)
        # Count successful transcriptions
        success_count = len([r for r in results if r is not None])
        # Return indices and results as a tuple for pairing
        return batch_id, list(zip(indices, results)), success_count, time.time() - start
    finally:
        # Release GPU memory in the worker even when an exception propagates.
        if model is not None:
            del model
            import torch
            torch.cuda.empty_cache()
# --- Parameters ---
chunk_size = 512 * 4  # 2048 files per worker task (one model load amortized over each)
n_workers = 6  # concurrent worker processes; each loads its own model copy on the GPU
checkpoint_interval = 250_000  # persist the DataFrame every N successful samples
config_path = "/home/ubuntu/NeMo_Canary/canary_results/Higurashi_ASR_v.02/version_4/hparams.yaml"
checkpoint_path = "/home/ubuntu/NeMo_Canary/canary_results/Higurashi_ASR_v.02_plus/checkpoints/Higurashi_ASR_v.02_plus--step=174650.0000-epoch=8-last.ckpt"
# --- Create batches from missing indices ---
# Slice the missing row labels into fixed-size batches; each batch dict
# carries everything a worker needs (row labels, file paths, model paths).
chunks = []
for batch_id, offset in enumerate(range(0, len(missing_indices), chunk_size)):
    rows = missing_indices[offset:offset + chunk_size]
    chunks.append({
        'batch_id': batch_id,
        'indices': rows,
        'files': df.loc[rows, 'filename'].tolist(),
        'config_path': config_path,
        'checkpoint_path': checkpoint_path,
    })
print(f"Total batches to process: {len(chunks)}")
print(f"Batch size: ~{chunk_size} samples")
print(f"Workers: {n_workers}")
print(f"Checkpoint interval: every {checkpoint_interval:,} samples")
print("-" * 50)
# --- Initialize tracking variables ---
# FIX: removed `all_results = {}` — it was never read or written by the
# active script (only the commented-out previous version used it).
failed_chunks = []        # chunks whose future raised (the whole batch is lost)
failed_files_list = []    # individual files whose transcription came back None
start_time = time.time()
samples_done = 0          # successful transcriptions this session
samples_failed = 0        # per-file failures this session
last_checkpoint = 0       # samples_done value at the last checkpoint save
interrupted = False       # set by the KeyboardInterrupt handler below
total_to_process = len(missing_indices)
# --- Main Processing Loop ---
try:
    with ProcessPoolExecutor(max_workers=n_workers) as executor:
        # Submit every batch up front; all chunk payloads are pickled to the
        # workers immediately. Map future -> chunk so failures can be traced.
        future_to_chunk = {
            executor.submit(process_batch,
                (chunk['batch_id'], chunk['indices'], chunk['files'],
                 chunk['config_path'], chunk['checkpoint_path'])): chunk
            for chunk in chunks
        }
        # Consume results in completion order, not submission order.
        for future in as_completed(future_to_chunk):
            original_chunk = future_to_chunk[future]
            batch_id = original_chunk['batch_id']
            try:
                _batch_id, index_result_pairs, success_count, batch_time = future.result()
                # Update DataFrame with results; None marks a per-file failure
                # and is written as the "[FAILED]" sentinel so it is not
                # re-picked as "missing" on the next resume.
                failed_in_batch = 0
                for idx, result in index_result_pairs:
                    if result is not None:
                        df.loc[idx, 'text'] = result
                    else:
                        df.loc[idx, 'text'] = "[FAILED]"
                        failed_in_batch += 1
                        failed_files_list.append(df.loc[idx, 'filename'])
                samples_done += success_count
                samples_failed += failed_in_batch
                # Progress/ETA bookkeeping (guards avoid division by zero).
                elapsed = time.time() - start_time
                speed = samples_done / elapsed if elapsed > 0 else 0
                remaining = total_to_process - samples_done - samples_failed
                eta = remaining / speed if speed > 0 else 0
                current_total = already_done + samples_done
                status = f"✓ Batch {batch_id}/{len(chunks)-1} done ({success_count} success"
                if failed_in_batch > 0:
                    status += f", {failed_in_batch} failed"
                status += f" in {batch_time:.1f}s)"
                print(f"{status} | "
                      f"Processed: {samples_done:,}/{total_to_process:,} | "
                      f"Total: {current_total:,}/{len(df):,} ({100*current_total/len(df):.1f}%) | "
                      f"Speed: {speed:.1f} samples/s | "
                      f"ETA: {timedelta(seconds=int(eta))}")
                # Save checkpoint every `checkpoint_interval` successes, and
                # once more when everything (success + failed) is accounted for.
                if samples_done - last_checkpoint >= checkpoint_interval or (samples_done + samples_failed) >= total_to_process:
                    checkpoint_file = f"/home/ubuntu/ttsar/ASR_DATA/transcribed_checkpoint_{current_total}.csv"
                    df.to_csv(checkpoint_file, index=False)
                    print(f" ✓ Checkpoint saved: {checkpoint_file}")
                    last_checkpoint = samples_done
            except Exception as e:
                # A whole batch failed (worker crash, model load error, ...).
                # Its rows stay NaN and will be retried on the next resume.
                failed_chunks.append(original_chunk)
                print("-" * 20 + " ERROR " + "-" * 20)
                print(f"✗ Batch {batch_id} FAILED. Indices count: {len(original_chunk['indices'])}")
                print(f"Error: {str(e)}")
                traceback.print_exc()
                print("-" * 47)
except KeyboardInterrupt:
    # NOTE(review): this handler runs only after the `with` block exits, and
    # ProcessPoolExecutor.__exit__ calls shutdown(wait=True) — so Ctrl-C
    # still waits for in-flight batches and does not cancel pending futures.
    # Confirm that delay is acceptable.
    interrupted = True
    print("\n\n" + "="*50)
    print("! KEYBOARD INTERRUPT DETECTED !")
    print("Stopping workers and saving progress...")
    print("="*50 + "\n")
# --- Finalization ---
# Runs after normal completion and after a Ctrl-C alike: report the session
# statistics and persist whatever progress the DataFrame currently holds.
total_time = time.time() - start_time
print("-" * 50)
print("PROCESS INTERRUPTED" if interrupted else "PROCESSING COMPLETE!")
print(f"Session time: {timedelta(seconds=int(total_time))}")
print(f"Samples successfully processed: {samples_done:,}")
print(f"Samples failed: {samples_failed:,}")
if samples_done > 0 and total_time > 0:
    print(f"Average speed: {samples_done/total_time:.1f} samples/second")
# Save final result
final_output = "/home/ubuntu/ttsar/ASR_DATA/transcribed_manifest_final.csv"
df.to_csv(final_output, index=False)
print(f"Final output saved to: {final_output}")
print("-" * 50)
# --- Summary ---
# FIX: `df['text'] == "[FAILED]"` propagates pd.NA for rows still missing a
# transcription, which can make the .sum() itself NA and crash the formatted
# prints below. Build the failure mask once, NA-safely, and reuse it.
failed_mask = df['text'].eq("[FAILED]").fillna(False)
failed_transcriptions = failed_mask.sum()
successful_transcriptions = df['text'].notna().sum() - failed_transcriptions
remaining_missing = df['text'].isna().sum()
print("Summary:")
print(f" - Total dataset size: {len(df):,} samples")
print(f" - Successfully transcribed: {successful_transcriptions:,} samples")
print(f" - Failed transcriptions: {failed_transcriptions:,} samples")
print(f" - Still missing (NaN): {remaining_missing:,} samples")
print(f" - Processed this session: {samples_done:,} successful, {samples_failed:,} failed")
print(f" - Failed batches (entire batch): {len(failed_chunks)}")
# Save list of failed files so they can be inspected or retried separately.
if failed_files_list:
    failed_files_path = "/home/ubuntu/ttsar/ASR_DATA/failed_transcription_files.txt"
    with open(failed_files_path, 'w') as f:
        for file_path in failed_files_list:
            f.write(f"{file_path}\n")
    print(f"\nFailed files saved to: {failed_files_path}")
# Record which whole batches failed (first few indices as a fingerprint).
if failed_chunks:
    failed_batches_path = "/home/ubuntu/ttsar/ASR_DATA/failed_batches.txt"
    with open(failed_batches_path, 'w') as f:
        for chunk in failed_chunks:
            f.write(f"Batch {chunk['batch_id']}: indices {chunk['indices'][:5]}... ({len(chunk['indices'])} total)\n")
    print(f"Failed batch info saved to: {failed_batches_path}")
print("-" * 50)