"""Evoxtral RL — Rejection sampling + SFT on best completions (RAFT). Follows the GRPO-for-ASR approach (arxiv:2509.01939) simplified for hackathon: 1. Generate N completions per training sample (with sampling) 2. Score each with rule-based reward (WER + Tag F1 + hallucination penalty) 3. Keep best completion per sample 4. SFT on the curated high-quality dataset (1 epoch, lower LR) Usage: modal run scripts/rl_modal.py modal run scripts/rl_modal.py --num-samples 4 --lr 5e-5 """ import os import modal image = ( modal.Image.debian_slim(python_version="3.11") .apt_install("ffmpeg", "libsndfile1") .pip_install( "torch>=2.4.0", "torchaudio>=2.4.0", "transformers==4.56.0", "datasets>=2.14.0", "accelerate>=1.0.0", "peft>=0.13.0", "wandb>=0.18.0", "jiwer>=3.0.0", "librosa>=0.10.0", "soundfile>=0.12.0", "huggingface_hub", "safetensors", "sentencepiece", "mistral-common", "torchcodec", gpu="A10G", ) .env({ "HF_HUB_CACHE": "/cache/huggingface", }) ) app = modal.App("evoxtral-rl", image=image) hf_cache = modal.Volume.from_name("evoxtral-hf-cache", create_if_missing=True) data_vol = modal.Volume.from_name("evoxtral-data", create_if_missing=True) output_vol = modal.Volume.from_name("evoxtral-output", create_if_missing=True) VOLUMES = { "/cache/huggingface": hf_cache, "/data": data_vol, "/output": output_vol, } MODEL_ID = "mistralai/Voxtral-Mini-3B-2507" SFT_ADAPTER = "/output/evoxtral-lora" RL_OUTPUT = "/output/evoxtral-rl" @app.function( gpu="A10G", volumes=VOLUMES, secrets=[ modal.Secret.from_name("wandb-secret"), modal.Secret.from_name("huggingface-secret"), ], timeout=7200, memory=32768, ) def generate_and_score(num_samples: int = 4, temperature: float = 0.7): """Step 1: Generate N completions per sample and score them.""" import torch import json import re from pathlib import Path from collections import Counter from jiwer import wer as compute_wer from datasets import load_from_disk, Audio from transformers import VoxtralForConditionalGeneration, AutoProcessor from peft import PeftModel print(f"GPU: {torch.cuda.get_device_name(0)}") # --- Reward helpers --- def extract_tags(text): return [m.group(1).lower() for m in re.finditer(r'\[([^\]]+)\]', text)] def strip_tags(text): text = re.sub(r'\[[^\]]+\]\s*', '', text) text = re.sub(r'\b[A-Z]{2,}\b', lambda m: m.group(0).lower(), text) return text.strip() def compute_reward(prediction, reference): """Rule-based reward: WER accuracy + Tag F1 - hallucination penalty.""" # WER component (accuracy = 1 - WER) ref_plain = strip_tags(reference) pred_plain = strip_tags(prediction) if ref_plain.strip(): wer_score = compute_wer(ref_plain, pred_plain) wer_accuracy = max(0.0, 1.0 - wer_score) else: wer_accuracy = 1.0 # Tag F1 component pred_tags = Counter(extract_tags(prediction)) ref_tags = Counter(extract_tags(reference)) if not ref_tags and not pred_tags: tag_f1 = 1.0 hall_rate = 0.0 elif not ref_tags or not pred_tags: tag_f1 = 0.0 hall_rate = 1.0 if pred_tags else 0.0 else: common = sum((pred_tags & ref_tags).values()) prec = common / sum(pred_tags.values()) rec = common / sum(ref_tags.values()) tag_f1 = 2 * prec * rec / (prec + rec) if (prec + rec) > 0 else 0.0 # Hallucination rate ref_set = set(ref_tags.keys()) hallucinated = sum(v for k, v in pred_tags.items() if k not in ref_set) hall_rate = hallucinated / sum(pred_tags.values()) # Combined reward reward = 0.4 * wer_accuracy + 0.4 * tag_f1 + 0.2 * (1.0 - hall_rate) return reward, {"wer_accuracy": wer_accuracy, "tag_f1": tag_f1, "hall_rate": hall_rate} # --- Load model --- print("Loading SFT model...") 
    # --- Load model ---
    print("Loading SFT model...")
    processor = AutoProcessor.from_pretrained(MODEL_ID)
    base_model = VoxtralForConditionalGeneration.from_pretrained(
        MODEL_ID,
        torch_dtype=torch.bfloat16,
        device_map="auto",
    )
    model = PeftModel.from_pretrained(base_model, SFT_ADAPTER)
    model.eval()
    print(f"Model loaded on {model.device}")

    # --- Load training data ---
    ds = load_from_disk("/data/processed")
    ds = ds.cast_column("audio", Audio(sampling_rate=16000))
    train_ds = ds["train"]
    print(f"Generating {num_samples} completions per sample for {len(train_ds)} training examples...")

    start_time = time.time()

    # Resume from checkpoint if it exists
    checkpoint_path = "/output/rl_curated_checkpoint.json"
    if Path(checkpoint_path).exists():
        with open(checkpoint_path) as f:
            curated_data = json.load(f)
        start_idx = len(curated_data)
        total_reward = sum(d["reward"] for d in curated_data)
        print(f"Resuming from checkpoint: {start_idx} samples already done")
    else:
        curated_data = []
        total_reward = 0.0
        start_idx = 0

    for i in range(start_idx, len(train_ds)):
        row = train_ds[i]
        reference = row["tagged_text"]
        audio_array = row["audio"]["array"]

        # Build inputs
        inputs = processor.apply_transcription_request(
            language="en",
            audio=[audio_array],
            format=["WAV"],
            model_id=MODEL_ID,
            return_tensors="pt",
        )
        inputs = {k: v.to(model.device) for k, v in inputs.items()}

        # Generate all N completions in one call with num_return_sequences
        with torch.no_grad():
            output_ids = model.generate(
                **inputs,
                max_new_tokens=512,
                do_sample=True,
                temperature=temperature,
                top_p=0.9,
                num_return_sequences=num_samples,
            )
        input_len = inputs["input_ids"].shape[1]

        # Score each completion, keep the best
        best_reward = -1.0
        best_prediction = None
        best_details = None
        for s in range(num_samples):
            prediction = processor.tokenizer.decode(
                output_ids[s][input_len:], skip_special_tokens=True
            )
            reward, details = compute_reward(prediction, reference)
            if reward > best_reward:
                best_reward = reward
                best_prediction = prediction
                best_details = details

        curated_data.append({
            "audio_idx": i,
            "reference": reference,
            "best_prediction": best_prediction,
            "reward": best_reward,
            **best_details,
        })
        total_reward += best_reward

        if i < 5 or i % 50 == 0:
            elapsed = time.time() - start_time
            done = i - start_idx + 1
            rate = done / elapsed if elapsed > 0 else 0
            eta = (len(train_ds) - i - 1) / rate if rate > 0 else 0
            print(f"  [{i}/{len(train_ds)}] reward={best_reward:.3f} "
                  f"wer_acc={best_details['wer_accuracy']:.3f} "
                  f"tag_f1={best_details['tag_f1']:.3f} "
                  f"hall={best_details['hall_rate']:.3f} "
                  f"({rate:.1f} samples/s, ETA {eta/60:.0f}min)")
            if i < 3:
                print(f"    ref:  {reference[:80]}...")
                print(f"    best: {best_prediction[:80]}...")

        # Save a checkpoint every 50 samples
        if (i + 1) % 50 == 0:
            with open(checkpoint_path, "w") as f:
                json.dump(curated_data, f)
            output_vol.commit()
            print(f"  [checkpoint saved: {len(curated_data)} samples]")
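    # Each curated record saved below is a flat dict, e.g. (illustrative values):
    #   {"audio_idx": 12, "reference": "[laugh] ...", "best_prediction": "...",
    #    "reward": 0.88, "wer_accuracy": 0.9, "tag_f1": 0.8, "hall_rate": 0.0}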
    avg_reward = total_reward / len(curated_data)
    print(f"\nGeneration complete! Avg reward: {avg_reward:.4f}")
    print(f"Curated {len(curated_data)} samples")

    # Save curated data
    output_path = "/output/rl_curated_data.json"
    with open(output_path, "w") as f:
        json.dump(curated_data, f)
    output_vol.commit()
    print(f"Saved curated data to {output_path}")

    return {"avg_reward": avg_reward, "num_samples": len(curated_data)}


@app.function(
    gpu="A10G",
    volumes=VOLUMES,
    secrets=[
        modal.Secret.from_name("wandb-secret"),
        modal.Secret.from_name("huggingface-secret"),
    ],
    timeout=7200,
    memory=32768,
)
def rl_finetune(learning_rate: float = 5e-5, num_epochs: int = 1, push_to_hub: bool = True):
    """Step 2: SFT on curated best completions (RAFT)."""
    import json
    import torch
    import wandb
    from datasets import Dataset, Audio, load_from_disk
    from transformers import (
        VoxtralForConditionalGeneration,
        AutoProcessor,
        TrainingArguments,
        Trainer,
    )
    from peft import PeftModel

    print(f"GPU: {torch.cuda.get_device_name(0)}")

    # --- Load curated data ---
    with open("/output/rl_curated_data.json") as f:
        curated_data = json.load(f)
    print(f"Loaded {len(curated_data)} curated samples")

    # Filter out low-reward samples (bottom 10%)
    rewards = [d["reward"] for d in curated_data]
    threshold = sorted(rewards)[len(rewards) // 10]
    curated_data = [d for d in curated_data if d["reward"] > threshold]
    print(f"After filtering (reward > {threshold:.3f}): {len(curated_data)} samples")

    # --- Load original audio dataset ---
    ds = load_from_disk("/data/processed")
    ds = ds.cast_column("audio", Audio(sampling_rate=16000))
    train_ds = ds["train"]

    # Build the RL training dataset: original audio + best prediction as target
    rl_examples = []
    for d in curated_data:
        idx = d["audio_idx"]
        row = train_ds[idx]
        rl_examples.append({
            "audio": row["audio"],
            "tagged_text": d["best_prediction"],  # RL target = best sampled completion
        })
    rl_dataset = Dataset.from_list(rl_examples)
    rl_dataset = rl_dataset.cast_column("audio", Audio(sampling_rate=16000))
    print(f"RL training dataset: {len(rl_dataset)} samples")

    # --- W&B ---
    run = wandb.init(
        project="evoxtral",
        name=f"rl-raft-lr{learning_rate}-ep{num_epochs}",
        config={
            "method": "RAFT (rejection sampling + SFT)",
            "base_adapter": "evoxtral-lora (SFT)",
            "learning_rate": learning_rate,
            "epochs": num_epochs,
            "num_curated": len(rl_dataset),
            "reward_threshold": threshold,
        },
        tags=["evoxtral", "rl", "raft", "rejection-sampling"],
    )
    avg_reward = sum(d["reward"] for d in curated_data) / len(curated_data)
    wandb.log({"rl/curated_samples": len(curated_data), "rl/avg_reward": avg_reward})

    # --- Load SFT model ---
    print("Loading SFT model for RL finetuning...")
    processor = AutoProcessor.from_pretrained(MODEL_ID)
    base_model = VoxtralForConditionalGeneration.from_pretrained(
        MODEL_ID,
        torch_dtype=torch.bfloat16,
        device_map="auto",
    )
    # is_trainable=True loads the adapter out of inference mode; the loop below
    # additionally ensures every LoRA weight is unfrozen for continued training.
    model = PeftModel.from_pretrained(base_model, SFT_ADAPTER, is_trainable=True)
    for name, param in model.named_parameters():
        if "lora" in name.lower():
            param.requires_grad = True
    model.print_trainable_parameters()

    # --- Data Collator (same as SFT) ---
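    # The collator below concatenates the audio prompt with the target text and
    # masks the prompt positions with -100, so the loss is computed only on the
    # target transcription tokens plus the trailing EOS.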
    class VoxtralDataCollator:
        def __init__(self, processor, model_id, max_text_len=512):
            self.processor = processor
            self.model_id = model_id
            self.max_text_len = max_text_len
            self.pad_id = processor.tokenizer.pad_token_id or processor.tokenizer.eos_token_id

        def __call__(self, examples):
            texts = [ex["tagged_text"] for ex in examples]
            audios = [ex["audio"]["array"] for ex in examples]

            prompt = self.processor.apply_transcription_request(
                language="en",
                model_id=self.model_id,
                audio=audios,
                format=["WAV"] * len(audios),
                return_tensors="pt",
            )
            passthrough = {k: v for k, v in prompt.items() if k not in ("input_ids", "attention_mask")}
            prompt_ids = prompt["input_ids"]
            prompt_attn = prompt["attention_mask"]
            B = prompt_ids.size(0)

            tok = self.processor.tokenizer
            text_tok = tok(
                texts,
                add_special_tokens=False,
                padding=False,
                truncation=True,
                max_length=self.max_text_len,
                return_tensors=None,
            )
            text_ids_list = text_tok["input_ids"]

            input_ids, attention_mask, labels = [], [], []
            for i in range(B):
                p_ids = prompt_ids[i].tolist()
                p_att = prompt_attn[i].tolist()
                t_ids = text_ids_list[i]
                ids = p_ids + t_ids + [tok.eos_token_id]
                attn = p_att + [1] * (len(t_ids) + 1)
                lab = [-100] * len(p_ids) + t_ids + [tok.eos_token_id]
                input_ids.append(ids)
                attention_mask.append(attn)
                labels.append(lab)

            max_len = max(len(x) for x in input_ids)

            def pad_to(seq, fill, L):
                return seq + [fill] * (L - len(seq))

            input_ids = [pad_to(x, self.pad_id, max_len) for x in input_ids]
            attention_mask = [pad_to(x, 0, max_len) for x in attention_mask]
            labels = [pad_to(x, -100, max_len) for x in labels]

            batch = {
                "input_ids": torch.tensor(input_ids, dtype=torch.long),
                "attention_mask": torch.tensor(attention_mask, dtype=torch.long),
                "labels": torch.tensor(labels, dtype=torch.long),
            }
            for k, v in passthrough.items():
                batch[k] = v
            return batch

    collator = VoxtralDataCollator(processor, MODEL_ID)

    # --- Training args (lower LR, 1 epoch) ---
    training_args = TrainingArguments(
        output_dir=RL_OUTPUT,
        num_train_epochs=num_epochs,
        per_device_train_batch_size=2,
        gradient_accumulation_steps=8,
        learning_rate=learning_rate,
        lr_scheduler_type="cosine",
        warmup_steps=20,
        weight_decay=0.01,
        max_grad_norm=1.0,
        bf16=True,
        gradient_checkpointing=True,
        gradient_checkpointing_kwargs={"use_reentrant": False},
        logging_steps=5,
        save_strategy="epoch",
        save_total_limit=2,
        report_to="wandb",
        remove_unused_columns=False,
        dataloader_pin_memory=True,
        dataloader_num_workers=4,
    )

    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=rl_dataset,
        data_collator=collator,
    )

    print("Starting RL finetuning...")
    train_result = trainer.train()

    wandb.log({
        "rl/final_loss": train_result.metrics.get("train_loss", 0),
        "rl/runtime_seconds": train_result.metrics.get("train_runtime", 0),
    })

    # Save adapter
    print(f"Saving RL adapter to {RL_OUTPUT}")
    trainer.save_model(RL_OUTPUT)
    processor.save_pretrained(RL_OUTPUT)

    # Log as W&B artifact
    artifact = wandb.Artifact(
        "evoxtral-rl-adapter",
        type="model",
        metadata={"method": "RAFT", "base": "evoxtral-lora"},
    )
    artifact.add_dir(RL_OUTPUT)
    run.log_artifact(artifact)

    # Push to Hub
    if push_to_hub:
        from huggingface_hub import HfApi

        HUB_ID = "YongkangZOU/evoxtral-rl"
        print(f"Pushing to HuggingFace Hub: {HUB_ID}")
        try:
            api = HfApi(token=os.environ.get("HF_TOKEN"))
            api.create_repo(HUB_ID, repo_type="model", exist_ok=True)
            api.upload_folder(
                folder_path=RL_OUTPUT,
                repo_id=HUB_ID,
                repo_type="model",
                commit_message=f"RL adapter (RAFT): lr={learning_rate}, ep={num_epochs}",
            )
            print(f"Pushed to {HUB_ID}")
        except Exception as e:
            print(f"Hub push failed: {e}")

    output_vol.commit()
    wandb.finish()
    print("RL finetuning complete!")
    return train_result.metrics
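# To try the trained adapter outside this pipeline (sketch; assumes the Hub
# push above succeeded, otherwise point at the /output/evoxtral-rl volume path):
#   base = VoxtralForConditionalGeneration.from_pretrained(MODEL_ID)
#   model = PeftModel.from_pretrained(base, "YongkangZOU/evoxtral-rl")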
@app.local_entrypoint()
def main(
    num_samples: int = 4,
    temperature: float = 0.7,
    lr: float = 5e-5,
    epochs: int = 1,
    push_to_hub: bool = True,
    finetune_only: bool = False,
):
    if not finetune_only:
        print("Step 1: Generating and scoring completions...")
        gen_results = generate_and_score.remote(
            num_samples=num_samples,
            temperature=temperature,
        )
        print(f"Generation results: {gen_results}")

    print("\nStep 2: RL finetuning on curated data...")
    ft_results = rl_finetune.remote(
        learning_rate=lr,
        num_epochs=epochs,
        push_to_hub=push_to_hub,
    )
    print(f"RL finetune results: {ft_results}")

    print("\nDone! Run eval with: modal run scripts/train_modal.py --eval-only")
    print("(Update adapter_path in evaluate() to /output/evoxtral-rl)")
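# To re-run only step 2 on previously curated data:
#   modal run scripts/rl_modal.py --finetune-only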