# from concurrent.futures import ProcessPoolExecutor, as_completed # import time # from datetime import timedelta # import pandas as pd # import torch # import warnings # import logging # import os # import traceback # # --- Load and filter dataframe --- # df = pd.read_csv("/home/ubuntu/ttsar/ASR_DATA/train_large.csv") # print('before filtering: ') # print(df.shape) # df = df[~df['filename'].str.contains("Sakura, Moyu")] # print('after filtering: ') # print(df.shape) # total_samples = len(df) # # --- PyTorch settings --- # torch.set_float32_matmul_precision('high') # torch.backends.cuda.matmul.allow_tf32 = True # torch.backends.cudnn.allow_tf32 = True # def process_batch(batch_data): # """Process a batch of audio files""" # batch_id, start_idx, audio_files, config_path, checkpoint_path = batch_data # model = None # Initialize model to None for the finally block # try: # # Import and configure libraries within the worker process # import torch # import nemo.collections.asr as nemo_asr # from omegaconf import OmegaConf, open_dict # import warnings # import logging # # Suppress logs within the worker process to keep the main output clean # logging.getLogger('nemo_logger').setLevel(logging.ERROR) # logging.disable(logging.CRITICAL) # warnings.filterwarnings('ignore') # # Load model for this worker # config = OmegaConf.load(config_path) # with open_dict(config.cfg): # for ds in ['train_ds', 'validation_ds', 'test_ds']: # if ds in config.cfg: # config.cfg[ds].defer_setup = True # model = nemo_asr.models.EncDecMultiTaskModel(cfg=config.cfg) # checkpoint = torch.load(checkpoint_path, map_location='cuda', weights_only=False) # model.load_state_dict(checkpoint['state_dict'], strict=False) # model = model.eval().cuda() # decode_cfg = model.cfg.decoding # decode_cfg.beam.beam_size = 4 # model.change_decoding_strategy(decode_cfg) # # Transcribe # start = time.time() # hypotheses = model.transcribe( # audio=audio_files, # batch_size=64, # source_lang='ja', # target_lang='ja', # task='asr', # pnc='no', # verbose=False, # num_workers=0, # channel_selector=0 # ) # results = [hyp.text for hyp in hypotheses] # return batch_id, start_idx, results, len(audio_files), time.time() - start # finally: # # NEW: Ensure GPU memory is cleared in the worker process # if model is not None: # del model # import torch # torch.cuda.empty_cache() # # --- Parameters --- # chunk_size = 512 * 4 # n_workers = 4 # checkpoint_interval = 250_000 # config_path = "/home/ubuntu/NeMo_Canary/canary_results/Higurashi_ASR_v.02/version_4/hparams.yaml" # checkpoint_path = "/home/ubuntu/NeMo_Canary/canary_results/Higurashi_ASR_v.02_plus/checkpoints/Higurashi_ASR_v.02_plus--step=174650.0000-epoch=8-last.ckpt" # # --- Prepare data chunks --- # audio_files = df['filename'].tolist() # chunks = [] # for i in range(0, total_samples, chunk_size): # end_idx = min(i + chunk_size, total_samples) # chunk_files = audio_files[i:end_idx] # chunks.append({ # 'batch_id': len(chunks), # 'start_idx': i, # 'files': chunk_files, # 'config_path': config_path, # 'checkpoint_path': checkpoint_path # }) # print(f"Processing {total_samples:,} samples") # print(f"Chunks: {len(chunks)} × ~{chunk_size} samples") # print(f"Workers: {n_workers}") # print(f"Checkpoint interval: every {checkpoint_interval:,} samples") # print("-" * 50) # # --- Initialize tracking variables --- # all_results = {} # failed_chunks = [] # start_time = time.time() # samples_done = 0 # last_checkpoint = 0 # interrupted = False # # Initialize 'text' column with a placeholder # df['text'] = pd.NA # # --- Main Processing Loop with Graceful Shutdown --- # try: # with ProcessPoolExecutor(max_workers=n_workers) as executor: # future_to_chunk = { # executor.submit(process_batch, # (chunk['batch_id'], chunk['start_idx'], chunk['files'], chunk['config_path'], chunk['checkpoint_path'])): chunk # for chunk in chunks # } # for future in as_completed(future_to_chunk): # original_chunk = future_to_chunk[future] # batch_id = original_chunk['batch_id'] # try: # _batch_id, start_idx, results, count, batch_time = future.result() # all_results[start_idx] = results # samples_done += count # end_idx = start_idx + len(results) # if len(df.iloc[start_idx:end_idx]) == len(results): # df.loc[start_idx:end_idx-1, 'text'] = results # else: # raise ValueError(f"Length mismatch: DataFrame slice vs results") # elapsed = time.time() - start_time # speed = samples_done / elapsed if elapsed > 0 else 0 # remaining = total_samples - samples_done # eta = remaining / speed if speed > 0 else 0 # print(f"✓ Batch {batch_id}/{len(chunks)-1} done ({count} samples in {batch_time:.1f}s) | " # f"Total: {samples_done:,}/{total_samples:,} ({100*samples_done/total_samples:.1f}%) | " # f"Speed: {speed:.1f} samples/s | " # f"ETA: {timedelta(seconds=int(eta))}") # if samples_done - last_checkpoint >= checkpoint_interval or samples_done == total_samples: # checkpoint_file = f"/home/ubuntu/ttsar/ASR_DATA/transcribed_checkpoint_{samples_done}.csv" # df.to_csv(checkpoint_file, index=False) # print(f" ✓ Checkpoint saved: {checkpoint_file}") # last_checkpoint = samples_done # except Exception: # failed_chunks.append(original_chunk) # print("-" * 20 + " ERROR " + "-" * 20) # print(f"✗ Batch {batch_id} FAILED. Start index: {original_chunk['start_idx']}. Files: {len(original_chunk['files'])}") # traceback.print_exc() # print("-" * 47) # except KeyboardInterrupt: # interrupted = True # print("\n\n" + "="*50) # print("! KEYBOARD INTERRUPT DETECTED !") # print("Stopping workers and saving all completed progress...") # print("The script will exit shortly.") # print("="*50 + "\n") # # The `with ProcessPoolExecutor` context manager will automatically # # handle shutting down the worker processes when we exit this block. # # --- Finalization and Reporting (this block now runs on completion OR interruption) --- # total_time = time.time() - start_time # print("-" * 50) # if interrupted: # print(f"PROCESS INTERRUPTED") # else: # print(f"TRANSCRIPTION COMPLETE!") # print(f"Total time elapsed: {timedelta(seconds=int(total_time))}") # if total_time > 0 and samples_done > 0: # print(f"Average speed (on completed work): {samples_done/total_time:.1f} samples/second") # # Save final result # final_output = "/home/ubuntu/ttsar/ASR_DATA/transcribed_manifest_final.csv" # df.to_csv(final_output, index=False) # print(f"Final progress saved to: {final_output}") # print("-" * 50) # # --- Summary and Verification --- # successful_transcriptions = df['text'].notna().sum() # print("Final Run Summary:") # print(f" - Successfully transcribed: {successful_transcriptions:,} samples") # print(f" - Failed batches: {len(failed_chunks)}") # print(f" - Total samples in failed batches: {sum(len(c['files']) for c in failed_chunks):,}") # if failed_chunks: # failed_files_path = "/home/ubuntu/ttsar/ASR_DATA/failed_transcription_files.txt" # with open(failed_files_path, 'w') as f: # for chunk in failed_chunks: # for file_path in chunk['files']: # f.write(f"{file_path}\n") # print(f"\nList of files from failed batches saved to: {failed_files_path}") # print("-" * 50) #NOTE #NOTE from concurrent.futures import ProcessPoolExecutor, as_completed import time from datetime import timedelta import pandas as pd import torch import warnings import logging import os import traceback # --- LOAD CHECKPOINT --- checkpoint_file = "/home/ubuntu/ttsar/csv_kanad/sing/cg_shani_sing.csv" print(f"Loading checkpoint from: {checkpoint_file}") df = pd.read_csv(checkpoint_file) print(f"Checkpoint loaded. Shape: {df.shape}") # Check if 'text' column exists, if not create it if 'text' not in df.columns: df['text'] = pd.NA # --- FIND ALL MISSING TRANSCRIPTIONS --- missing_mask = df['text'].isna() missing_indices = df[missing_mask].index.tolist() already_done = (~missing_mask).sum() print(f"Already transcribed: {already_done:,} samples") print(f"Missing transcriptions: {len(missing_indices):,} samples") print("-" * 50) if len(missing_indices) == 0: print("All samples already transcribed!") exit(0) # --- PyTorch settings --- torch.set_float32_matmul_precision('high') torch.backends.cuda.matmul.allow_tf32 = True torch.backends.cudnn.allow_tf32 = True def process_batch(batch_data): """Process a batch of audio files""" batch_id, indices, audio_files, config_path, checkpoint_path = batch_data model = None try: # Import and configure libraries within the worker process import torch import nemo.collections.asr as nemo_asr from omegaconf import OmegaConf, open_dict import warnings import logging # Suppress logs within the worker process logging.getLogger('nemo_logger').setLevel(logging.ERROR) logging.disable(logging.CRITICAL) warnings.filterwarnings('ignore') # Load model for this worker config = OmegaConf.load(config_path) with open_dict(config.cfg): for ds in ['train_ds', 'validation_ds', 'test_ds']: if ds in config.cfg: config.cfg[ds].defer_setup = True model = nemo_asr.models.EncDecMultiTaskModel(cfg=config.cfg) checkpoint = torch.load(checkpoint_path, map_location='cuda', weights_only=False) model.load_state_dict(checkpoint['state_dict'], strict=False) model = model.eval().cuda().bfloat16() decode_cfg = model.cfg.decoding decode_cfg.beam.beam_size = 1 model.change_decoding_strategy(decode_cfg) # Transcribe start = time.time() try: hypotheses = model.transcribe( audio=audio_files, batch_size=64, source_lang='ja', target_lang='ja', task='asr', pnc='no', verbose=False, num_workers=0, channel_selector=0 ) results = [hyp.text for hyp in hypotheses] except Exception as e: print(f"Transcription error in batch {batch_id}: {str(e)}") # Return empty results list on transcription failure results = [] # Pad results with None if we got fewer results than expected while len(results) < len(audio_files): results.append(None) # Count successful transcriptions success_count = len([r for r in results if r is not None]) # Return indices and results as a tuple for pairing return batch_id, list(zip(indices, results)), success_count, time.time() - start finally: if model is not None: del model import torch torch.cuda.empty_cache() # --- Parameters --- chunk_size = 512 * 4 # 2048 n_workers = 6 checkpoint_interval = 250_000 config_path = "/home/ubuntu/NeMo_Canary/canary_results/Higurashi_ASR_v.02/version_4/hparams.yaml" checkpoint_path = "/home/ubuntu/NeMo_Canary/canary_results/Higurashi_ASR_v.02_plus/checkpoints/Higurashi_ASR_v.02_plus--step=174650.0000-epoch=8-last.ckpt" # --- Create batches from missing indices --- chunks = [] for i in range(0, len(missing_indices), chunk_size): batch_indices = missing_indices[i:i+chunk_size] batch_files = df.loc[batch_indices, 'filename'].tolist() chunks.append({ 'batch_id': len(chunks), 'indices': batch_indices, 'files': batch_files, 'config_path': config_path, 'checkpoint_path': checkpoint_path }) print(f"Total batches to process: {len(chunks)}") print(f"Batch size: ~{chunk_size} samples") print(f"Workers: {n_workers}") print(f"Checkpoint interval: every {checkpoint_interval:,} samples") print("-" * 50) # --- Initialize tracking variables --- all_results = {} failed_chunks = [] failed_files_list = [] start_time = time.time() samples_done = 0 samples_failed = 0 last_checkpoint = 0 interrupted = False total_to_process = len(missing_indices) # --- Main Processing Loop --- try: with ProcessPoolExecutor(max_workers=n_workers) as executor: future_to_chunk = { executor.submit(process_batch, (chunk['batch_id'], chunk['indices'], chunk['files'], chunk['config_path'], chunk['checkpoint_path'])): chunk for chunk in chunks } for future in as_completed(future_to_chunk): original_chunk = future_to_chunk[future] batch_id = original_chunk['batch_id'] try: _batch_id, index_result_pairs, success_count, batch_time = future.result() # Update DataFrame with results failed_in_batch = 0 for idx, result in index_result_pairs: if result is not None: df.loc[idx, 'text'] = result else: df.loc[idx, 'text'] = "[FAILED]" failed_in_batch += 1 failed_files_list.append(df.loc[idx, 'filename']) samples_done += success_count samples_failed += failed_in_batch elapsed = time.time() - start_time speed = samples_done / elapsed if elapsed > 0 else 0 remaining = total_to_process - samples_done - samples_failed eta = remaining / speed if speed > 0 else 0 current_total = already_done + samples_done status = f"✓ Batch {batch_id}/{len(chunks)-1} done ({success_count} success" if failed_in_batch > 0: status += f", {failed_in_batch} failed" status += f" in {batch_time:.1f}s)" print(f"{status} | " f"Processed: {samples_done:,}/{total_to_process:,} | " f"Total: {current_total:,}/{len(df):,} ({100*current_total/len(df):.1f}%) | " f"Speed: {speed:.1f} samples/s | " f"ETA: {timedelta(seconds=int(eta))}") # Save checkpoint if samples_done - last_checkpoint >= checkpoint_interval or (samples_done + samples_failed) >= total_to_process: checkpoint_file = f"/home/ubuntu/ttsar/ASR_DATA/transcribed_checkpoint_{current_total}.csv" df.to_csv(checkpoint_file, index=False) print(f" ✓ Checkpoint saved: {checkpoint_file}") last_checkpoint = samples_done except Exception as e: failed_chunks.append(original_chunk) print("-" * 20 + " ERROR " + "-" * 20) print(f"✗ Batch {batch_id} FAILED. Indices count: {len(original_chunk['indices'])}") print(f"Error: {str(e)}") traceback.print_exc() print("-" * 47) except KeyboardInterrupt: interrupted = True print("\n\n" + "="*50) print("! KEYBOARD INTERRUPT DETECTED !") print("Stopping workers and saving progress...") print("="*50 + "\n") # --- Finalization --- total_time = time.time() - start_time print("-" * 50) if interrupted: print(f"PROCESS INTERRUPTED") else: print(f"PROCESSING COMPLETE!") print(f"Session time: {timedelta(seconds=int(total_time))}") print(f"Samples successfully processed: {samples_done:,}") print(f"Samples failed: {samples_failed:,}") if total_time > 0 and samples_done > 0: print(f"Average speed: {samples_done/total_time:.1f} samples/second") # Save final result final_output = "/home/ubuntu/ttsar/ASR_DATA/transcribed_manifest_final.csv" df.to_csv(final_output, index=False) print(f"Final output saved to: {final_output}") print("-" * 50) # --- Summary --- successful_transcriptions = df['text'].notna().sum() - (df['text'] == "[FAILED]").sum() failed_transcriptions = (df['text'] == "[FAILED]").sum() remaining_missing = df['text'].isna().sum() print("Summary:") print(f" - Total dataset size: {len(df):,} samples") print(f" - Successfully transcribed: {successful_transcriptions:,} samples") print(f" - Failed transcriptions: {failed_transcriptions:,} samples") print(f" - Still missing (NaN): {remaining_missing:,} samples") print(f" - Processed this session: {samples_done:,} successful, {samples_failed:,} failed") print(f" - Failed batches (entire batch): {len(failed_chunks)}") # Save list of failed files if failed_files_list: failed_files_path = "/home/ubuntu/ttsar/ASR_DATA/failed_transcription_files.txt" with open(failed_files_path, 'w') as f: for file_path in failed_files_list: f.write(f"{file_path}\n") print(f"\nFailed files saved to: {failed_files_path}") if failed_chunks: failed_batches_path = "/home/ubuntu/ttsar/ASR_DATA/failed_batches.txt" with open(failed_batches_path, 'w') as f: for chunk in failed_chunks: f.write(f"Batch {chunk['batch_id']}: indices {chunk['indices'][:5]}... ({len(chunk['indices'])} total)\n") print(f"Failed batch info saved to: {failed_batches_path}") print("-" * 50)