feat: Setup completo para treinamento Qwen3-0.6B speech embeddings
Browse files- Implementa pipeline de treinamento baseado em LLaMA-Omni2 + LoRA-Whisper
- Adiciona validação mínima (130 samples, 15-20 minutos)
- Configura Common Voice 22 PT dataset
- Cria Speech Projector + LoRA integration
- Pipeline experimental Qwen3 para testes
🤖 Generated with [Claude Code](https://claude.ai/code)
Co-Authored-By: Claude <noreply@anthropic.com>
- pipelines/llama_omni2_experimental_qwen3.py +455 -0
- training/qwen3-0.6b/README.md +233 -0
- training/qwen3-0.6b/config/training_config.yaml +189 -0
- training/qwen3-0.6b/data/prepare_cv22.py +364 -0
- training/qwen3-0.6b/data/synthetic_samples.py +288 -0
- training/qwen3-0.6b/requirements.txt +94 -0
- training/qwen3-0.6b/scripts/quick_validation.py +424 -0
- training/qwen3-0.6b/scripts/run_minimal_validation.py +361 -0
- training/qwen3-0.6b/scripts/train_stage1.py +491 -0
- training/qwen3-0.6b/scripts/utils.py +474 -0
pipelines/llama_omni2_experimental_qwen3.py
ADDED
|
@@ -0,0 +1,455 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""
|
| 3 |
+
LLaMA-Omni2 EXPERIMENTAL com Qwen3-0.6B
|
| 4 |
+
==========================================
|
| 5 |
+
Pipeline experimental baseado no oficial adaptado para usar Qwen3-0.6B
|
| 6 |
+
|
| 7 |
+
DIFERENÇAS DO OFICIAL:
|
| 8 |
+
=====================
|
| 9 |
+
|
| 10 |
+
1. LLM BASE: Qwen3-0.6B (ao invés de Qwen2)
|
| 11 |
+
- Modelo: "Qwen/Qwen3-0.6B" (0.6B parâmetros)
|
| 12 |
+
- Arquitetura: Qwen3ForCausalLM
|
| 13 |
+
- Hidden size: 1024 dimensões (diferente do Qwen2: 896)
|
| 14 |
+
- Vocabulário: ~152.000 tokens
|
| 15 |
+
- Modos: thinking/non-thinking
|
| 16 |
+
|
| 17 |
+
2. SPEECH PROJECTOR ADAPTADO:
|
| 18 |
+
- Output adaptado para 1024 dims (hidden_size do Qwen3)
|
| 19 |
+
- Arquitetura: Linear(6400, 2048) → ReLU → Linear(2048, 1024)
|
| 20 |
+
|
| 21 |
+
3. DEPENDÊNCIAS:
|
| 22 |
+
- transformers >= 4.51.0 (suporte ao Qwen3)
|
| 23 |
+
- torch >= 2.0
|
| 24 |
+
- Demais iguais ao oficial
|
| 25 |
+
|
| 26 |
+
ARQUITETURA EXPERIMENTAL:
|
| 27 |
+
========================
|
| 28 |
+
|
| 29 |
+
1. WHISPER ENCODER (Igual ao oficial)
|
| 30 |
+
- Modelo: whisper-large-v3 (1.55B parâmetros)
|
| 31 |
+
- Output: Embeddings [batch, time//2, 1280]
|
| 32 |
+
|
| 33 |
+
2. SPEECH PROJECTOR (Adaptado para Qwen3)
|
| 34 |
+
- Arquitetura: Linear(6400, 2048) → ReLU → Linear(2048, 1024)
|
| 35 |
+
- Output: Features projetadas [batch, seq_len, 1024]
|
| 36 |
+
|
| 37 |
+
3. LLM (Qwen3-0.6B)
|
| 38 |
+
- Modelo: Qwen3ForCausalLM (0.6B parâmetros)
|
| 39 |
+
- Hidden size: 1024 dimensões
|
| 40 |
+
- Modo: Padrão (non-thinking)
|
| 41 |
+
|
| 42 |
+
4. TTS (Igual ao oficial)
|
| 43 |
+
- Biblioteca: gTTS
|
| 44 |
+
|
| 45 |
+
NOTAS EXPERIMENTAIS:
|
| 46 |
+
===================
|
| 47 |
+
- Este é um pipeline EXPERIMENTAL para testar Qwen3
|
| 48 |
+
- Pode ter menor performance que o oficial
|
| 49 |
+
- Qwen3 pode responder de forma diferente
|
| 50 |
+
- Compatibilidade com embeddings não garantida
|
| 51 |
+
"""
|
| 52 |
+
|
| 53 |
+
import torch
|
| 54 |
+
import torch.nn as nn
|
| 55 |
+
import numpy as np
|
| 56 |
+
import whisper
|
| 57 |
+
from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig
|
| 58 |
+
from safetensors.torch import load_file
|
| 59 |
+
import os
|
| 60 |
+
import json
|
| 61 |
+
import logging
|
| 62 |
+
from typing import Tuple, Optional
|
| 63 |
+
from gtts import gTTS
|
| 64 |
+
import tempfile
|
| 65 |
+
import soundfile as sf
|
| 66 |
+
|
| 67 |
+
logging.basicConfig(level=logging.INFO, format='%(message)s')
|
| 68 |
+
logger = logging.getLogger(__name__)
|
| 69 |
+
|
| 70 |
+
# Constantes iguais ao oficial
|
| 71 |
+
SPEECH_TOKEN_INDEX = -200
|
| 72 |
+
DEFAULT_SPEECH_TOKEN = "<speech>"
|
| 73 |
+
IGNORE_INDEX = -100
|
| 74 |
+
|
| 75 |
+
|
| 76 |
+
class LLaMAOmni2Qwen3Experimental:
|
| 77 |
+
"""Implementação experimental com Qwen3-0.6B"""
|
| 78 |
+
|
| 79 |
+
def __init__(self, device="cuda"):
|
| 80 |
+
self.device = device
|
| 81 |
+
self.qwen3_model_name = "Qwen/Qwen3-0.6B"
|
| 82 |
+
|
| 83 |
+
logger.info("\n" + "="*80)
|
| 84 |
+
logger.info("🧪 LLaMA-Omni2 - Pipeline EXPERIMENTAL com Qwen3-0.6B")
|
| 85 |
+
logger.info("="*80)
|
| 86 |
+
|
| 87 |
+
# 1. Carregar Whisper
|
| 88 |
+
logger.info("📦 Carregando Whisper...")
|
| 89 |
+
self._load_whisper()
|
| 90 |
+
|
| 91 |
+
# 2. Carregar Qwen3
|
| 92 |
+
logger.info("🤖 Carregando Qwen3-0.6B...")
|
| 93 |
+
self._load_qwen3()
|
| 94 |
+
|
| 95 |
+
# 3. Criar componentes adaptados
|
| 96 |
+
logger.info("🔧 Criando componentes adaptados...")
|
| 97 |
+
self._setup_components()
|
| 98 |
+
|
| 99 |
+
# 4. gTTS para síntese
|
| 100 |
+
self.tts_enabled = True
|
| 101 |
+
|
| 102 |
+
logger.info("="*80)
|
| 103 |
+
logger.info("✅ Pipeline experimental carregado!")
|
| 104 |
+
logger.info(f"📊 Hidden size: {self.hidden_size}")
|
| 105 |
+
logger.info("="*80)
|
| 106 |
+
|
| 107 |
+
def _load_whisper(self):
|
| 108 |
+
"""Carrega Whisper (igual ao oficial)"""
|
| 109 |
+
model_path = "models/large-v3.pt"
|
| 110 |
+
if os.path.exists(model_path):
|
| 111 |
+
self.whisper_model = whisper.load_model(model_path, device=self.device)
|
| 112 |
+
else:
|
| 113 |
+
self.whisper_model = whisper.load_model("large-v3", device=self.device)
|
| 114 |
+
|
| 115 |
+
def _load_qwen3(self):
|
| 116 |
+
"""Carrega modelo Qwen3-0.6B"""
|
| 117 |
+
try:
|
| 118 |
+
# Carregar configuração primeiro
|
| 119 |
+
config = AutoConfig.from_pretrained(self.qwen3_model_name)
|
| 120 |
+
self.hidden_size = config.hidden_size
|
| 121 |
+
|
| 122 |
+
logger.info(f" • Hidden size detectado: {self.hidden_size}")
|
| 123 |
+
|
| 124 |
+
# Carregar modelo com torch_dtype consistente
|
| 125 |
+
self.model = AutoModelForCausalLM.from_pretrained(
|
| 126 |
+
self.qwen3_model_name,
|
| 127 |
+
torch_dtype=torch.float32, # Usar float32 consistente
|
| 128 |
+
device_map="auto",
|
| 129 |
+
trust_remote_code=True
|
| 130 |
+
)
|
| 131 |
+
|
| 132 |
+
# Detectar dtype do modelo
|
| 133 |
+
self.model_dtype = next(self.model.parameters()).dtype
|
| 134 |
+
logger.info(f" • Model dtype: {self.model_dtype}")
|
| 135 |
+
|
| 136 |
+
# Carregar tokenizer
|
| 137 |
+
self.tokenizer = AutoTokenizer.from_pretrained(
|
| 138 |
+
self.qwen3_model_name,
|
| 139 |
+
use_fast=False,
|
| 140 |
+
trust_remote_code=True
|
| 141 |
+
)
|
| 142 |
+
|
| 143 |
+
# Configurar pad token
|
| 144 |
+
if self.tokenizer.pad_token is None:
|
| 145 |
+
self.tokenizer.pad_token = self.tokenizer.eos_token
|
| 146 |
+
|
| 147 |
+
# Adicionar speech token se não existir
|
| 148 |
+
if DEFAULT_SPEECH_TOKEN not in self.tokenizer.get_vocab():
|
| 149 |
+
self.tokenizer.add_tokens([DEFAULT_SPEECH_TOKEN])
|
| 150 |
+
logger.info(f" • Adicionado token {DEFAULT_SPEECH_TOKEN}")
|
| 151 |
+
|
| 152 |
+
self.model.eval()
|
| 153 |
+
|
| 154 |
+
except Exception as e:
|
| 155 |
+
logger.error(f"❌ Erro ao carregar Qwen3: {e}")
|
| 156 |
+
raise e
|
| 157 |
+
|
| 158 |
+
def _setup_components(self):
|
| 159 |
+
"""Configura componentes adaptados para Qwen3"""
|
| 160 |
+
# Speech encoder (igual ao oficial)
|
| 161 |
+
self.speech_encoder = WhisperEncoder(self.whisper_model, self.device)
|
| 162 |
+
|
| 163 |
+
# Speech projector adaptado para o hidden_size do Qwen3
|
| 164 |
+
self.speech_projector = SpeechProjectorQwen3(
|
| 165 |
+
encoder_dim=1280,
|
| 166 |
+
llm_dim=self.hidden_size, # Usar hidden_size do Qwen3
|
| 167 |
+
k=5
|
| 168 |
+
).to(self.device)
|
| 169 |
+
|
| 170 |
+
logger.info(f" • Speech projector: 1280 → {self.hidden_size}")
|
| 171 |
+
|
| 172 |
+
def load_speech(self, audio: np.ndarray) -> torch.Tensor:
|
| 173 |
+
"""
|
| 174 |
+
Carrega speech (igual ao oficial)
|
| 175 |
+
"""
|
| 176 |
+
# Pad ou trim para 30 segundos
|
| 177 |
+
audio = whisper.pad_or_trim(audio)
|
| 178 |
+
|
| 179 |
+
# Criar mel spectrogram
|
| 180 |
+
mel = whisper.log_mel_spectrogram(audio, n_mels=128)
|
| 181 |
+
|
| 182 |
+
# CRÍTICO: Permutar dimensões!
|
| 183 |
+
mel = mel.permute(1, 0)
|
| 184 |
+
|
| 185 |
+
return mel
|
| 186 |
+
|
| 187 |
+
def encode_speech(self, speech_mel: torch.Tensor) -> torch.Tensor:
|
| 188 |
+
"""Processa mel através do encoder e projector"""
|
| 189 |
+
# 1. Passar pelo encoder do Whisper
|
| 190 |
+
speech_features = self.speech_encoder(speech_mel)
|
| 191 |
+
|
| 192 |
+
# 2. Passar pelo projector adaptado
|
| 193 |
+
projected = self.speech_projector(speech_features)
|
| 194 |
+
|
| 195 |
+
return projected
|
| 196 |
+
|
| 197 |
+
@torch.no_grad()
|
| 198 |
+
def generate(self,
|
| 199 |
+
audio: np.ndarray,
|
| 200 |
+
max_new_tokens: int = 100,
|
| 201 |
+
temperature: float = 0.7) -> str:
|
| 202 |
+
"""
|
| 203 |
+
Gera resposta usando Qwen3
|
| 204 |
+
"""
|
| 205 |
+
# 1. Processar áudio
|
| 206 |
+
speech_mel = self.load_speech(audio)
|
| 207 |
+
|
| 208 |
+
# 2. Criar mensagens (adaptado para Qwen3)
|
| 209 |
+
messages = [
|
| 210 |
+
{"role": "user", "content": DEFAULT_SPEECH_TOKEN},
|
| 211 |
+
{"role": "assistant", "content": ""}
|
| 212 |
+
]
|
| 213 |
+
|
| 214 |
+
# 3. Aplicar chat template do Qwen3
|
| 215 |
+
try:
|
| 216 |
+
input_ids = self.tokenizer.apply_chat_template(
|
| 217 |
+
messages,
|
| 218 |
+
add_generation_prompt=True,
|
| 219 |
+
return_tensors="pt"
|
| 220 |
+
)[0]
|
| 221 |
+
except Exception as e:
|
| 222 |
+
# Fallback se apply_chat_template falhar
|
| 223 |
+
logger.warning(f"⚠️ Chat template falhou: {e}")
|
| 224 |
+
text = f"user: {DEFAULT_SPEECH_TOKEN}\nassistant:"
|
| 225 |
+
input_ids = self.tokenizer.encode(text, return_tensors="pt")[0]
|
| 226 |
+
|
| 227 |
+
# 4. Substituir speech token
|
| 228 |
+
input_ids[input_ids == self.tokenizer.convert_tokens_to_ids(DEFAULT_SPEECH_TOKEN)] = SPEECH_TOKEN_INDEX
|
| 229 |
+
input_ids = input_ids.unsqueeze(0).to(self.device)
|
| 230 |
+
|
| 231 |
+
# 5. Processar speech
|
| 232 |
+
speech_tensor = speech_mel.unsqueeze(0).to(self.device)
|
| 233 |
+
speech_features = self.encode_speech(speech_tensor)
|
| 234 |
+
|
| 235 |
+
# 6. Preparar inputs com embeddings
|
| 236 |
+
input_embeds = self.prepare_inputs_with_speech(
|
| 237 |
+
input_ids,
|
| 238 |
+
speech_features
|
| 239 |
+
)
|
| 240 |
+
|
| 241 |
+
# 7. Gerar resposta com Qwen3
|
| 242 |
+
outputs = self.model.generate(
|
| 243 |
+
inputs_embeds=input_embeds,
|
| 244 |
+
max_new_tokens=max_new_tokens,
|
| 245 |
+
temperature=temperature,
|
| 246 |
+
do_sample=True,
|
| 247 |
+
top_p=0.95,
|
| 248 |
+
use_cache=True,
|
| 249 |
+
pad_token_id=self.tokenizer.pad_token_id,
|
| 250 |
+
eos_token_id=self.tokenizer.eos_token_id,
|
| 251 |
+
bos_token_id=getattr(self.tokenizer, 'bos_token_id', self.tokenizer.pad_token_id)
|
| 252 |
+
)
|
| 253 |
+
|
| 254 |
+
# 8. Decodificar resposta
|
| 255 |
+
response = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
|
| 256 |
+
|
| 257 |
+
# Limpar resposta (adaptado para Qwen3)
|
| 258 |
+
if "assistant" in response:
|
| 259 |
+
response = response.split("assistant")[-1].strip()
|
| 260 |
+
if "<|im_end|>" in response:
|
| 261 |
+
response = response.split("<|im_end|>")[0].strip()
|
| 262 |
+
if "<|endoftext|>" in response:
|
| 263 |
+
response = response.split("<|endoftext|>")[0].strip()
|
| 264 |
+
|
| 265 |
+
return response
|
| 266 |
+
|
| 267 |
+
def prepare_inputs_with_speech(self, input_ids, speech_features):
|
| 268 |
+
"""
|
| 269 |
+
Combina input_ids com speech features (igual ao oficial)
|
| 270 |
+
"""
|
| 271 |
+
logger.info(f" • Input IDs shape: {input_ids.shape}")
|
| 272 |
+
logger.info(f" • Speech features shape: {speech_features.shape}")
|
| 273 |
+
|
| 274 |
+
# Criar máscara
|
| 275 |
+
speech_token_mask = (input_ids == SPEECH_TOKEN_INDEX)
|
| 276 |
+
|
| 277 |
+
# Substituir por token válido temporariamente
|
| 278 |
+
temp_input_ids = input_ids.clone()
|
| 279 |
+
temp_input_ids[speech_token_mask] = self.tokenizer.pad_token_id
|
| 280 |
+
|
| 281 |
+
# Obter embeddings e garantir dtype consistente
|
| 282 |
+
input_embeds = self.model.get_input_embeddings()(temp_input_ids)
|
| 283 |
+
|
| 284 |
+
# Ajustar dtype do speech_features para match com input_embeds
|
| 285 |
+
speech_features = speech_features.to(dtype=input_embeds.dtype, device=input_embeds.device)
|
| 286 |
+
|
| 287 |
+
if speech_token_mask.any():
|
| 288 |
+
batch_size = input_ids.shape[0]
|
| 289 |
+
|
| 290 |
+
for b in range(batch_size):
|
| 291 |
+
speech_indices = torch.where(speech_token_mask[b])[0]
|
| 292 |
+
|
| 293 |
+
if len(speech_indices) > 0:
|
| 294 |
+
speech_idx = speech_indices[0].item()
|
| 295 |
+
|
| 296 |
+
# Dividir embeddings
|
| 297 |
+
before = input_embeds[b, :speech_idx]
|
| 298 |
+
after = input_embeds[b, speech_idx+1:]
|
| 299 |
+
speech = speech_features[b]
|
| 300 |
+
|
| 301 |
+
# Garantir 2D
|
| 302 |
+
if before.dim() == 1:
|
| 303 |
+
before = before.unsqueeze(0)
|
| 304 |
+
if after.dim() == 1:
|
| 305 |
+
after = after.unsqueeze(0)
|
| 306 |
+
if speech.dim() == 1:
|
| 307 |
+
speech = speech.unsqueeze(0)
|
| 308 |
+
|
| 309 |
+
# Combinar
|
| 310 |
+
parts = []
|
| 311 |
+
if before.shape[0] > 0:
|
| 312 |
+
parts.append(before)
|
| 313 |
+
if speech.shape[0] > 0:
|
| 314 |
+
parts.append(speech)
|
| 315 |
+
if after.shape[0] > 0:
|
| 316 |
+
parts.append(after)
|
| 317 |
+
|
| 318 |
+
combined = torch.cat(parts, dim=0).unsqueeze(0)
|
| 319 |
+
input_embeds = combined
|
| 320 |
+
|
| 321 |
+
return input_embeds
|
| 322 |
+
|
| 323 |
+
def synthesize_speech(self, text: str, lang: str = "pt") -> str:
|
| 324 |
+
"""Sintetiza fala com gTTS (igual ao oficial)"""
|
| 325 |
+
try:
|
| 326 |
+
tts = gTTS(text=text, lang=lang, slow=False)
|
| 327 |
+
with tempfile.NamedTemporaryFile(delete=False, suffix=".mp3") as f:
|
| 328 |
+
tts.save(f.name)
|
| 329 |
+
temp_mp3 = f.name
|
| 330 |
+
|
| 331 |
+
# Converter para WAV
|
| 332 |
+
temp_wav = temp_mp3.replace(".mp3", ".wav")
|
| 333 |
+
data, sr = sf.read(temp_mp3)
|
| 334 |
+
sf.write(temp_wav, data, sr)
|
| 335 |
+
|
| 336 |
+
os.remove(temp_mp3)
|
| 337 |
+
return temp_wav
|
| 338 |
+
except Exception as e:
|
| 339 |
+
logger.error(f"Erro na síntese: {e}")
|
| 340 |
+
return None
|
| 341 |
+
|
| 342 |
+
def process(self, audio: np.ndarray) -> Tuple[str, Optional[str]]:
|
| 343 |
+
"""Pipeline completo"""
|
| 344 |
+
try:
|
| 345 |
+
# 1. Gerar texto
|
| 346 |
+
response_text = self.generate(audio)
|
| 347 |
+
logger.info(f"💬 Resposta Qwen3: {response_text}")
|
| 348 |
+
|
| 349 |
+
# 2. Sintetizar áudio
|
| 350 |
+
audio_path = None
|
| 351 |
+
if response_text and self.tts_enabled:
|
| 352 |
+
audio_path = self.synthesize_speech(response_text)
|
| 353 |
+
|
| 354 |
+
return response_text, audio_path
|
| 355 |
+
|
| 356 |
+
except Exception as e:
|
| 357 |
+
logger.error(f"❌ Erro: {e}")
|
| 358 |
+
import traceback
|
| 359 |
+
traceback.print_exc()
|
| 360 |
+
return "", None
|
| 361 |
+
|
| 362 |
+
|
| 363 |
+
class WhisperEncoder(nn.Module):
|
| 364 |
+
"""Wrapper para o encoder do Whisper (igual ao oficial)"""
|
| 365 |
+
|
| 366 |
+
def __init__(self, whisper_model, device):
|
| 367 |
+
super().__init__()
|
| 368 |
+
self.encoder = whisper_model.encoder
|
| 369 |
+
self.device = device
|
| 370 |
+
self.encoder.eval()
|
| 371 |
+
|
| 372 |
+
def forward(self, mel):
|
| 373 |
+
"""Forward através do encoder do Whisper"""
|
| 374 |
+
with torch.no_grad():
|
| 375 |
+
# Input: [batch, time, 128]
|
| 376 |
+
# Whisper espera: [batch, 128, time]
|
| 377 |
+
if mel.dim() == 3:
|
| 378 |
+
mel = mel.permute(0, 2, 1)
|
| 379 |
+
elif mel.dim() == 2:
|
| 380 |
+
mel = mel.unsqueeze(0).permute(0, 2, 1)
|
| 381 |
+
|
| 382 |
+
features = self.encoder(mel)
|
| 383 |
+
|
| 384 |
+
return features # [batch, time//2, 1280]
|
| 385 |
+
|
| 386 |
+
|
| 387 |
+
class SpeechProjectorQwen3(nn.Module):
|
| 388 |
+
"""Speech Projector adaptado para Qwen3"""
|
| 389 |
+
|
| 390 |
+
def __init__(self, encoder_dim=1280, llm_dim=1024, k=5):
|
| 391 |
+
super().__init__()
|
| 392 |
+
self.k = k
|
| 393 |
+
|
| 394 |
+
# Adaptado para hidden_size do Qwen3
|
| 395 |
+
self.linear1 = nn.Linear(encoder_dim * k, 2048)
|
| 396 |
+
self.relu = nn.ReLU()
|
| 397 |
+
self.linear2 = nn.Linear(2048, llm_dim) # llm_dim será o hidden_size do Qwen3
|
| 398 |
+
|
| 399 |
+
def forward(self, x):
|
| 400 |
+
batch_size, seq_len, dim = x.size()
|
| 401 |
+
|
| 402 |
+
# Downsampling por fator k
|
| 403 |
+
num_frames_to_discard = seq_len % self.k
|
| 404 |
+
if num_frames_to_discard > 0:
|
| 405 |
+
x = x[:, :-num_frames_to_discard, :]
|
| 406 |
+
seq_len = x.size(1)
|
| 407 |
+
|
| 408 |
+
# Reshape
|
| 409 |
+
x = x.contiguous()
|
| 410 |
+
x = x.view(batch_size, seq_len // self.k, dim * self.k)
|
| 411 |
+
|
| 412 |
+
# Duas camadas
|
| 413 |
+
x = self.linear1(x)
|
| 414 |
+
x = self.relu(x)
|
| 415 |
+
x = self.linear2(x)
|
| 416 |
+
|
| 417 |
+
return x
|
| 418 |
+
|
| 419 |
+
|
| 420 |
+
def test_qwen3_experimental():
|
| 421 |
+
"""Testa a implementação experimental com Qwen3"""
|
| 422 |
+
print("\n" + "="*80)
|
| 423 |
+
print("🧪 TESTE EXPERIMENTAL - QWEN3-0.6B")
|
| 424 |
+
print("="*80)
|
| 425 |
+
|
| 426 |
+
device = "cuda" if torch.cuda.is_available() else "cpu"
|
| 427 |
+
|
| 428 |
+
try:
|
| 429 |
+
model = LLaMAOmni2Qwen3Experimental(device=device)
|
| 430 |
+
except Exception as e:
|
| 431 |
+
print(f"❌ Erro ao carregar modelo: {e}")
|
| 432 |
+
return
|
| 433 |
+
|
| 434 |
+
# Criar áudio de teste
|
| 435 |
+
print("\n📊 Testando com áudio...")
|
| 436 |
+
audio = np.random.randn(16000 * 3).astype(np.float32) * 0.01
|
| 437 |
+
|
| 438 |
+
print("🔄 Processando com Qwen3...")
|
| 439 |
+
response, audio_path = model.process(audio)
|
| 440 |
+
|
| 441 |
+
print("-"*40)
|
| 442 |
+
if response:
|
| 443 |
+
print(f"✅ SUCESSO! Resposta Qwen3: {response}")
|
| 444 |
+
else:
|
| 445 |
+
print(f"❌ Resposta vazia")
|
| 446 |
+
|
| 447 |
+
if audio_path and os.path.exists(audio_path):
|
| 448 |
+
print(f"🔊 Áudio: {audio_path}")
|
| 449 |
+
os.remove(audio_path)
|
| 450 |
+
|
| 451 |
+
print("="*80)
|
| 452 |
+
|
| 453 |
+
|
| 454 |
+
if __name__ == "__main__":
|
| 455 |
+
test_qwen3_experimental()
|
training/qwen3-0.6b/README.md
ADDED
|
@@ -0,0 +1,233 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# 🎤 Qwen3-0.6B Speech Embeddings Training
|
| 2 |
+
|
| 3 |
+
## 📚 Academic Foundation & References
|
| 4 |
+
|
| 5 |
+
Baseado nas metodologias dos principais papers acadêmicos para treinamento de embeddings de fala em LLMs:
|
| 6 |
+
|
| 7 |
+
### 🎯 **Papers Fundamentais:**
|
| 8 |
+
|
| 9 |
+
1. **LLaMA-Omni2** (2025) - *LLM-based Real-time Spoken Chatbot with Autoregressive Streaming Speech Synthesis*
|
| 10 |
+
- **ArXiv**: [2505.02625](https://arxiv.org/abs/2505.02625)
|
| 11 |
+
- **Metodologia**: Two-stage training (Speech-to-Text → Speech-to-Speech)
|
| 12 |
+
- **Dataset**: InstructS2S-200K samples
|
| 13 |
+
- **Inovação**: Speech embeddings sem transcrição intermediária
|
| 14 |
+
|
| 15 |
+
2. **LoRA-Whisper** (2024) - *Parameter-Efficient and Extensible Multilingual ASR*
|
| 16 |
+
- **ArXiv**: [2406.06619](https://arxiv.org/html/2406.06619v1)
|
| 17 |
+
- **Contribuição**: Evita interferência linguística com LoRA modules específicos por idioma
|
| 18 |
+
- **Performance**: +18.5% ganho relativo em ASR multilingual
|
| 19 |
+
- **Relevância**: Demonstra eficácia do LoRA para adaptar Whisper
|
| 20 |
+
|
| 21 |
+
3. **LoRA: Low-Rank Adaptation** (2021) - *Low-Rank Adaptation of Large Language Models*
|
| 22 |
+
- **ArXiv**: [2106.09685](https://arxiv.org/abs/2106.09685)
|
| 23 |
+
- **Impacto**: Reduz parâmetros treináveis em 10.000x
|
| 24 |
+
- **Eficiência**: 3x menos memória GPU vs fine-tuning completo
|
| 25 |
+
- **Base**: Foundation for parameter-efficient speech training
|
| 26 |
+
|
| 27 |
+
4. **Speech2Vec** (2018) - *Learning Word Embeddings from Speech*
|
| 28 |
+
- **ArXiv**: [1803.08976](https://arxiv.org/abs/1803.08976)
|
| 29 |
+
- **Conceito**: Fixed-length vector representations from speech
|
| 30 |
+
- **Relevância**: Early work on semantic speech embeddings
|
| 31 |
+
|
| 32 |
+
5. **StyleSpeech** (2024) - *Parameter-efficient Fine Tuning for Pre-trained Controllable Text-to-Speech*
|
| 33 |
+
- **ArXiv**: [2408.14713](https://arxiv.org/abs/2408.14713)
|
| 34 |
+
- **Técnica**: LoRA aplicado a modelos de síntese de fala
|
| 35 |
+
- **Aplicação**: Adaptation de features de estilo com eficiência
|
| 36 |
+
|
| 37 |
+
### 🧠 **Fundamentação Teórica:**
|
| 38 |
+
|
| 39 |
+
**Por que Whisper Embeddings funcionam:**
|
| 40 |
+
- Whisper encoder produz representações semânticas ricas (1280 dims)
|
| 41 |
+
- Treinado em 680K horas de áudio multilingual
|
| 42 |
+
- Captura informações prosódicas e fonéticas além do conteúdo
|
| 43 |
+
|
| 44 |
+
**Por que LoRA é essencial:**
|
| 45 |
+
- Evita *catastrophic forgetting* do conhecimento pré-treinado
|
| 46 |
+
- Permite especialização para embeddings de fala
|
| 47 |
+
- Reduz drasticamente tempo e recursos de treinamento
|
| 48 |
+
|
| 49 |
+
**Arquitetura Speech Adapter:**
|
| 50 |
+
```
|
| 51 |
+
Whisper Encoder [1280] → Speech Projector [1024] → Qwen3 + LoRA
|
| 52 |
+
↓ Frozen ↓ Trainable ↓ LoRA adapters
|
| 53 |
+
```
|
| 54 |
+
|
| 55 |
+
## 🎯 **Objetivo do Treinamento**
|
| 56 |
+
|
| 57 |
+
Ensinar o **Qwen3-0.6B** a entender embeddings de fala do **Whisper Large-v3** através de:
|
| 58 |
+
|
| 59 |
+
1. **Speech Projector**: Mapear Whisper[1280] → Qwen3[1024]
|
| 60 |
+
2. **LoRA Fine-tuning**: Adaptar Qwen3 para processar embeddings de fala
|
| 61 |
+
3. **Common Voice PT**: Dataset português para **transcrição básica** (validação inicial)
|
| 62 |
+
|
| 63 |
+
**FOCO INICIAL**: Testar se o modelo consegue "ouvir" áudio e repetir/transcrever o que ouviu.
|
| 64 |
+
|
| 65 |
+
## 🗂️ **Estrutura do Treinamento**
|
| 66 |
+
|
| 67 |
+
```
|
| 68 |
+
training/qwen3-0.6b/
|
| 69 |
+
├── README.md # Este arquivo
|
| 70 |
+
├── config/
|
| 71 |
+
│ ├── training_config.yaml # Hiperparâmetros de treinamento
|
| 72 |
+
│ ├── lora_config.yaml # Configuração LoRA
|
| 73 |
+
│ └── dataset_config.yaml # Configuração de dataset
|
| 74 |
+
├── scripts/
|
| 75 |
+
│ ├── train_stage1.py # Stage I: Speech-to-Text
|
| 76 |
+
│ ├── train_stage2.py # Stage II: Speech-to-Speech
|
| 77 |
+
│ ├── prepare_dataset.py # Preprocessamento Common Voice
|
| 78 |
+
│ ├── evaluate_model.py # Avaliação e métricas
|
| 79 |
+
│ └── utils.py # Funções auxiliares
|
| 80 |
+
├── models/
|
| 81 |
+
│ ├── speech_adapter.py # Speech Projector implementation
|
| 82 |
+
│ ├── lora_qwen3.py # Qwen3 com LoRA integration
|
| 83 |
+
│ └── training_pipeline.py # Pipeline completo
|
| 84 |
+
├── data/
|
| 85 |
+
│ ├── prepare_cv22.py # Processar Common Voice 22
|
| 86 |
+
│ ├── synthetic_samples.py # Gerar samples sintéticos
|
| 87 |
+
│ └── portuguese_instructions.json # Instruções em PT-BR
|
| 88 |
+
└── checkpoints/ # Modelos salvos durante treinamento
|
| 89 |
+
├── stage1_best.pt
|
| 90 |
+
├── stage2_best.pt
|
| 91 |
+
└── final_model.pt
|
| 92 |
+
```
|
| 93 |
+
|
| 94 |
+
## ⚙️ **Configuração de Treinamento**
|
| 95 |
+
|
| 96 |
+
### **Hardware Mínimo:**
|
| 97 |
+
- 1x RTX 4090 (24GB VRAM)
|
| 98 |
+
- 32GB RAM sistem
|
| 99 |
+
- 500GB SSD espaço livre
|
| 100 |
+
|
| 101 |
+
### **Hardware Ideal:**
|
| 102 |
+
- 4x RTX 4090 (96GB VRAM total)
|
| 103 |
+
- 128GB RAM
|
| 104 |
+
- NVMe SSD 2TB+
|
| 105 |
+
|
| 106 |
+
### **Tempo Estimado:**
|
| 107 |
+
- **LoRA (Recommended)**: 8-12 horas
|
| 108 |
+
- **Full Fine-tuning**: 48-72 horas
|
| 109 |
+
- **Dataset prep**: 2-4 horas
|
| 110 |
+
|
| 111 |
+
## 📊 **Metodologia Baseada em Papers**
|
| 112 |
+
|
| 113 |
+
### **Stage I: Speech-to-Text Training**
|
| 114 |
+
```python
|
| 115 |
+
# Baseado em LLaMA-Omni2 Stage I(a)
|
| 116 |
+
- Freeze: Whisper encoder
|
| 117 |
+
- Train: Speech Projector + Qwen3 (LoRA)
|
| 118 |
+
- Epochs: 3
|
| 119 |
+
- Batch Size: 32
|
| 120 |
+
- Learning Rate: 5e-5 (LoRA), 5e-4 (Projector)
|
| 121 |
+
- Optimizer: AdamW
|
| 122 |
+
- Scheduler: Cosine with warmup
|
| 123 |
+
```
|
| 124 |
+
|
| 125 |
+
### **Stage II: Speech-to-Speech Enhancement**
|
| 126 |
+
```python
|
| 127 |
+
# Opcional - para síntese de fala
|
| 128 |
+
- Freeze: Whisper + Speech Projector + Qwen3
|
| 129 |
+
- Train: TTS components (se aplicável)
|
| 130 |
+
- Epochs: 1
|
| 131 |
+
- Learning Rate: 1e-3
|
| 132 |
+
```
|
| 133 |
+
|
| 134 |
+
## 🎛️ **Configuração LoRA (Otimizada)**
|
| 135 |
+
|
| 136 |
+
Baseado em **LoRA-Whisper** e análises de eficiência:
|
| 137 |
+
|
| 138 |
+
```yaml
|
| 139 |
+
lora_config:
|
| 140 |
+
r: 16 # Rank (balance between efficiency/performance)
|
| 141 |
+
alpha: 32 # Scaling factor (2x rank is optimal)
|
| 142 |
+
dropout: 0.1 # Prevent overfitting
|
| 143 |
+
target_modules: # Apply to attention matrices
|
| 144 |
+
- "q_proj" # Query projection
|
| 145 |
+
- "k_proj" # Key projection
|
| 146 |
+
- "v_proj" # Value projection
|
| 147 |
+
- "o_proj" # Output projection
|
| 148 |
+
bias: "none" # Don't adapt bias terms
|
| 149 |
+
task_type: "CAUSAL_LM" # Causal language modeling
|
| 150 |
+
```
|
| 151 |
+
|
| 152 |
+
## 📈 **Métricas de Avaliação**
|
| 153 |
+
|
| 154 |
+
### **Métricas Primárias:**
|
| 155 |
+
1. **Perplexity**: Quão bem o modelo "entende" embeddings
|
| 156 |
+
2. **BLEU Score**: Qualidade das respostas geradas
|
| 157 |
+
3. **Semantic Similarity**: Cosine similarity entre embeddings
|
| 158 |
+
4. **Response Coherence**: Avaliação humana das respostas
|
| 159 |
+
|
| 160 |
+
### **Métricas Secundárias:**
|
| 161 |
+
1. **Training Loss**: Convergência durante treinamento
|
| 162 |
+
2. **Validation Loss**: Overfitting detection
|
| 163 |
+
3. **Memory Usage**: Eficiência de recursos
|
| 164 |
+
4. **Inference Speed**: Latência de resposta
|
| 165 |
+
|
| 166 |
+
## 🌍 **Dataset: Common Voice 22 + Instruções PT-BR**
|
| 167 |
+
|
| 168 |
+
### **Common Voice Statistics:**
|
| 169 |
+
- **Português**: ~300 horas de áudio validado
|
| 170 |
+
- **Speakers**: ~15.000 falantes únicos
|
| 171 |
+
- **Diversity**: Sotaques regionais do Brasil/Portugal
|
| 172 |
+
- **Quality**: Crowd-sourced, quality-controlled
|
| 173 |
+
|
| 174 |
+
### **Augmentation Strategy:**
|
| 175 |
+
1. **Instruction Rewriting**: Converter frases CV para instruções
|
| 176 |
+
2. **Response Generation**: GPT-4 gerar respostas em PT-BR
|
| 177 |
+
3. **Audio Synthesis**: TTS para respostas (opcional)
|
| 178 |
+
4. **Noise Augmentation**: Simular condições reais
|
| 179 |
+
|
| 180 |
+
## 🔬 **Experimentos Planejados**
|
| 181 |
+
|
| 182 |
+
### **Baseline Experiments:**
|
| 183 |
+
1. **E1**: LoRA r=8 vs r=16 vs r=32
|
| 184 |
+
2. **E2**: Projector hidden_dim 1024 vs 2048 vs 4096
|
| 185 |
+
3. **E3**: Dataset size 10K vs 50K vs 200K samples
|
| 186 |
+
4. **E4**: Learning rate schedules comparison
|
| 187 |
+
|
| 188 |
+
### **Advanced Experiments:**
|
| 189 |
+
1. **E5**: Multi-lingual training (PT + EN)
|
| 190 |
+
2. **E6**: Adapter vs LoRA vs Full fine-tuning
|
| 191 |
+
3. **E7**: Different Whisper model sizes
|
| 192 |
+
4. **E8**: Synthetic vs Real audio comparison
|
| 193 |
+
|
| 194 |
+
## 🎯 **Success Criteria**
|
| 195 |
+
|
| 196 |
+
### **Minimum Viable Performance:**
|
| 197 |
+
- [ ] Model generates non-empty responses to speech input
|
| 198 |
+
- [ ] Responses are coherent and relevant
|
| 199 |
+
- [ ] Training converges without overfitting
|
| 200 |
+
- [ ] BLEU score > 0.15 on test set
|
| 201 |
+
|
| 202 |
+
### **Target Performance:**
|
| 203 |
+
- [ ] BLEU score > 0.35 (competitive with baselines)
|
| 204 |
+
- [ ] Perplexity < 15 on speech embeddings
|
| 205 |
+
- [ ] Response latency < 1 second
|
| 206 |
+
- [ ] Handles Portuguese speech variations
|
| 207 |
+
|
| 208 |
+
### **Stretch Goals:**
|
| 209 |
+
- [ ] Multilingual capability (PT + EN)
|
| 210 |
+
- [ ] Real-time inference (< 500ms)
|
| 211 |
+
- [ ] Emotion/prosody understanding
|
| 212 |
+
- [ ] Few-shot learning for new domains
|
| 213 |
+
|
| 214 |
+
## 📚 **Próximos Passos**
|
| 215 |
+
|
| 216 |
+
1. **Setup Environment** (`pip install -r requirements.txt`)
|
| 217 |
+
2. **Prepare Dataset** (`python data/prepare_cv22.py`)
|
| 218 |
+
3. **Run Stage I** (`python scripts/train_stage1.py`)
|
| 219 |
+
4. **Evaluate Results** (`python scripts/evaluate_model.py`)
|
| 220 |
+
5. **Deploy & Test** (Integration with main pipeline)
|
| 221 |
+
|
| 222 |
+
## 🤝 **Contribuições**
|
| 223 |
+
|
| 224 |
+
Este treinamento é baseado em metodologias state-of-the-art e representa uma aplicação prática dos avanços acadêmicos em speech embeddings para LLMs.
|
| 225 |
+
|
| 226 |
+
**Key Innovations:**
|
| 227 |
+
- Primeira aplicação de LoRA-Whisper methodology ao Qwen3
|
| 228 |
+
- Dataset brasileiro estruturado para instruction-following
|
| 229 |
+
- Pipeline end-to-end para speech-to-speech em português
|
| 230 |
+
|
| 231 |
+
---
|
| 232 |
+
|
| 233 |
+
*"Teaching machines to truly understand speech, not just transcribe it."* 🎤✨
|
training/qwen3-0.6b/config/training_config.yaml
ADDED
|
@@ -0,0 +1,189 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# 🎤 Qwen3-0.6B Speech Embeddings Training Configuration
|
| 2 |
+
# Based on LLaMA-Omni2 official methodology + LoRA-Whisper best practices
|
| 3 |
+
|
| 4 |
+
model:
|
| 5 |
+
name: "Qwen/Qwen3-0.6B"
|
| 6 |
+
hidden_size: 1024
|
| 7 |
+
device: "cuda"
|
| 8 |
+
torch_dtype: "float32" # For compatibility
|
| 9 |
+
trust_remote_code: true
|
| 10 |
+
|
| 11 |
+
# Whisper configuration
|
| 12 |
+
whisper:
|
| 13 |
+
model_name: "large-v3"
|
| 14 |
+
model_path: "/workspace/llama-omni2-compact/models/large-v3.pt" # Caminho absoluto atualizado
|
| 15 |
+
encoder_dim: 1280
|
| 16 |
+
freeze_encoder: true # CRITICAL: Always freeze Whisper
|
| 17 |
+
|
| 18 |
+
# Speech Projector configuration
|
| 19 |
+
speech_projector:
|
| 20 |
+
input_dim: 1280 # Whisper encoder output
|
| 21 |
+
hidden_dim: 2048 # Following LLaMA-Omni2 paper
|
| 22 |
+
output_dim: 1024 # Qwen3-0.6B hidden size
|
| 23 |
+
downsample_factor: 5 # k=5 as in original paper
|
| 24 |
+
dropout: 0.1
|
| 25 |
+
|
| 26 |
+
# LoRA configuration (optimized for speech adaptation)
|
| 27 |
+
lora:
|
| 28 |
+
r: 16 # Rank - balance efficiency/performance
|
| 29 |
+
alpha: 32 # Scaling factor (2x rank optimal)
|
| 30 |
+
dropout: 0.1 # Regularization
|
| 31 |
+
target_modules: # Apply to attention matrices only
|
| 32 |
+
- "q_proj" # Query projection
|
| 33 |
+
- "k_proj" # Key projection
|
| 34 |
+
- "v_proj" # Value projection
|
| 35 |
+
- "o_proj" # Output projection
|
| 36 |
+
bias: "none" # Don't adapt bias terms
|
| 37 |
+
task_type: "CAUSAL_LM"
|
| 38 |
+
inference_mode: false
|
| 39 |
+
|
| 40 |
+
# Training Stage I: Speech-to-Text (Following LLaMA-Omni2)
|
| 41 |
+
stage1:
|
| 42 |
+
# VALIDAÇÃO MÍNIMA (para testes rápidos)
|
| 43 |
+
minimal_validation:
|
| 44 |
+
epochs: 1 # Apenas 1 epoch para teste
|
| 45 |
+
batch_size: 4 # Batch pequeno para rapidez
|
| 46 |
+
max_steps: 50 # Máximo 50 steps
|
| 47 |
+
|
| 48 |
+
# TREINAMENTO COMPLETO
|
| 49 |
+
full_training:
|
| 50 |
+
epochs: 3 # 3 epochs como no paper
|
| 51 |
+
batch_size: 32 # Batch size otimizado
|
| 52 |
+
gradient_accumulation_steps: 1
|
| 53 |
+
|
| 54 |
+
# Learning rates (different for different components)
|
| 55 |
+
learning_rates:
|
| 56 |
+
speech_projector: 5e-4 # Higher LR for projector
|
| 57 |
+
lora: 5e-5 # Lower LR for LoRA adapters
|
| 58 |
+
|
| 59 |
+
# Optimizer
|
| 60 |
+
optimizer: "adamw"
|
| 61 |
+
weight_decay: 0.01
|
| 62 |
+
beta1: 0.9
|
| 63 |
+
beta2: 0.999
|
| 64 |
+
eps: 1e-8
|
| 65 |
+
|
| 66 |
+
# Scheduler
|
| 67 |
+
scheduler: "cosine"
|
| 68 |
+
warmup_ratio: 0.03 # 3% warmup as in paper
|
| 69 |
+
min_lr_ratio: 0.1
|
| 70 |
+
|
| 71 |
+
# Regularization
|
| 72 |
+
max_grad_norm: 1.0
|
| 73 |
+
label_smoothing: 0.1
|
| 74 |
+
|
| 75 |
+
# Logging & Saving
|
| 76 |
+
logging_steps: 10
|
| 77 |
+
eval_steps: 100
|
| 78 |
+
save_steps: 500
|
| 79 |
+
save_total_limit: 3
|
| 80 |
+
|
| 81 |
+
# Early stopping
|
| 82 |
+
early_stopping_patience: 5
|
| 83 |
+
metric_for_best_model: "eval_loss"
|
| 84 |
+
|
| 85 |
+
# Training Stage II: Speech-to-Speech Enhancement (Optional)
|
| 86 |
+
stage2:
|
| 87 |
+
epochs: 1
|
| 88 |
+
batch_size: 32
|
| 89 |
+
learning_rate: 1e-3
|
| 90 |
+
|
| 91 |
+
# Freeze everything except TTS components
|
| 92 |
+
freeze_components:
|
| 93 |
+
- "whisper_encoder"
|
| 94 |
+
- "speech_projector"
|
| 95 |
+
- "qwen3_base"
|
| 96 |
+
- "lora_adapters"
|
| 97 |
+
|
| 98 |
+
# Dataset configuration
|
| 99 |
+
dataset:
|
| 100 |
+
# Common Voice 22 Portuguese - CAMINHO ATUALIZADO E ORGANIZADO
|
| 101 |
+
common_voice:
|
| 102 |
+
corpus_path: "/workspace/llama-omni2-compact/training/cv-corpus-22.0-2025-06-20-pt/cv-corpus-22.0-2025-06-20/pt"
|
| 103 |
+
language: "pt"
|
| 104 |
+
version: "22.0"
|
| 105 |
+
|
| 106 |
+
# MODO DE VALIDAÇÃO MÍNIMA (para testes rápidos - 130 samples total)
|
| 107 |
+
minimal_validation:
|
| 108 |
+
enabled: true # ATIVO por padrão para validação
|
| 109 |
+
max_samples:
|
| 110 |
+
train: 100 # Apenas 100 samples para teste rápido
|
| 111 |
+
validation: 20 # 20 para validação
|
| 112 |
+
test: 10 # 10 para teste final
|
| 113 |
+
max_audio_length: 10 # Áudios menores para rapidez
|
| 114 |
+
|
| 115 |
+
# MODO TREINAMENTO COMPLETO (desabilitar minimal_validation.enabled)
|
| 116 |
+
full_training:
|
| 117 |
+
max_samples: 50000 # 50K samples (escalável até 200K)
|
| 118 |
+
max_audio_length: 30 # segundos
|
| 119 |
+
sample_rate: 16000
|
| 120 |
+
|
| 121 |
+
# Configurações gerais
|
| 122 |
+
split_ratios:
|
| 123 |
+
train: 0.8
|
| 124 |
+
validation: 0.1
|
| 125 |
+
test: 0.1
|
| 126 |
+
|
| 127 |
+
# Synthetic instruction data - REMOVIDO PARA VALIDAÇÃO INICIAL
|
| 128 |
+
# instructions:
|
| 129 |
+
# file: "data/portuguese_instructions.json"
|
| 130 |
+
# augmentation:
|
| 131 |
+
# paraphrase: true
|
| 132 |
+
# noise_injection: 0.1 # 10% noise augmentation
|
| 133 |
+
# speed_perturbation: 0.15
|
| 134 |
+
|
| 135 |
+
# Data preprocessing
|
| 136 |
+
preprocessing:
|
| 137 |
+
normalize_audio: true
|
| 138 |
+
trim_silence: true
|
| 139 |
+
pad_or_trim: true
|
| 140 |
+
mel_spectrogram: true
|
| 141 |
+
|
| 142 |
+
# Evaluation metrics
|
| 143 |
+
evaluation:
|
| 144 |
+
metrics:
|
| 145 |
+
- "perplexity" # Primary metric
|
| 146 |
+
- "bleu" # Response quality
|
| 147 |
+
- "rouge" # Content overlap
|
| 148 |
+
- "semantic_similarity" # Embedding similarity
|
| 149 |
+
|
| 150 |
+
# Reference dataset for evaluation - SIMPLIFICADO
|
| 151 |
+
test_questions:
|
| 152 |
+
- "Hoje está um dia muito bonito."
|
| 153 |
+
- "Gosto de escutar música clássica."
|
| 154 |
+
- "O Brasil é um país muito diverso."
|
| 155 |
+
|
| 156 |
+
# Hardware optimization
|
| 157 |
+
hardware:
|
| 158 |
+
mixed_precision: true # Enable AMP for speed
|
| 159 |
+
gradient_checkpointing: true # Save memory
|
| 160 |
+
dataloader_num_workers: 4
|
| 161 |
+
pin_memory: true
|
| 162 |
+
|
| 163 |
+
# Memory optimization
|
| 164 |
+
max_memory_mb: 20000 # 20GB max memory usage
|
| 165 |
+
empty_cache_steps: 100 # Clear cache every N steps
|
| 166 |
+
|
| 167 |
+
# Paths
|
| 168 |
+
paths:
|
| 169 |
+
base_dir: "/workspace/llama-omni2-compact/training/qwen3-0.6b"
|
| 170 |
+
data_dir: "data"
|
| 171 |
+
checkpoints_dir: "checkpoints"
|
| 172 |
+
logs_dir: "logs"
|
| 173 |
+
results_dir: "results"
|
| 174 |
+
|
| 175 |
+
# Reproducibility
|
| 176 |
+
seed: 42
|
| 177 |
+
deterministic: true
|
| 178 |
+
|
| 179 |
+
# Monitoring
|
| 180 |
+
wandb:
|
| 181 |
+
enabled: false # Set true if using Weights & Biases
|
| 182 |
+
project: "qwen3-speech-embeddings"
|
| 183 |
+
run_name: "stage1-lora-r16"
|
| 184 |
+
|
| 185 |
+
# Debug mode
|
| 186 |
+
debug:
|
| 187 |
+
enabled: false
|
| 188 |
+
max_steps: 100 # Limit steps in debug mode
|
| 189 |
+
small_dataset: true # Use tiny dataset for debugging
|
training/qwen3-0.6b/data/prepare_cv22.py
ADDED
|
@@ -0,0 +1,364 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""
|
| 3 |
+
Common Voice 22 Dataset Preparation
|
| 4 |
+
===================================
|
| 5 |
+
Processes the Portuguese Common Voice dataset for speech embeddings training
|
| 6 |
+
Supports minimal validation mode for quick testing
|
| 7 |
+
"""
|
| 8 |
+
|
| 9 |
+
import os
|
| 10 |
+
import sys
|
| 11 |
+
import tarfile
|
| 12 |
+
import pandas as pd
|
| 13 |
+
import soundfile as sf
|
| 14 |
+
import numpy as np
|
| 15 |
+
from pathlib import Path
|
| 16 |
+
import logging
|
| 17 |
+
import argparse
|
| 18 |
+
from typing import List, Dict, Tuple, Optional
|
| 19 |
+
import json
|
| 20 |
+
import random
|
| 21 |
+
from tqdm import tqdm
|
| 22 |
+
|
| 23 |
+
logger = logging.getLogger(__name__)
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
class CommonVoice22Processor:
|
| 27 |
+
"""
|
| 28 |
+
Process Common Voice 22 Portuguese dataset
|
| 29 |
+
|
| 30 |
+
Features:
|
| 31 |
+
- Extract and organize audio files
|
| 32 |
+
- Create train/validation/test splits
|
| 33 |
+
- Generate instruction-following samples
|
| 34 |
+
- Support for minimal validation mode (fast testing)
|
| 35 |
+
"""
|
| 36 |
+
|
| 37 |
+
def __init__(self, corpus_path: str, output_dir: str, minimal_mode: bool = False):
|
| 38 |
+
self.corpus_path = Path(corpus_path)
|
| 39 |
+
self.output_dir = Path(output_dir)
|
| 40 |
+
self.minimal_mode = minimal_mode
|
| 41 |
+
|
| 42 |
+
# Dataset paths - o corpus já está extraído
|
| 43 |
+
if self.corpus_path.is_dir():
|
| 44 |
+
# Corpus já extraído
|
| 45 |
+
self.cv_extracted_path = self.corpus_path
|
| 46 |
+
else:
|
| 47 |
+
# Corpus ainda compactado (fallback)
|
| 48 |
+
self.cv_extracted_path = self.output_dir / "cv-corpus-22-pt"
|
| 49 |
+
self.processed_path = self.output_dir / "processed"
|
| 50 |
+
self.audio_dir = self.processed_path / "clips"
|
| 51 |
+
|
| 52 |
+
# Create directories
|
| 53 |
+
self.processed_path.mkdir(parents=True, exist_ok=True)
|
| 54 |
+
self.audio_dir.mkdir(parents=True, exist_ok=True)
|
| 55 |
+
|
| 56 |
+
# Sample limits for modes
|
| 57 |
+
if minimal_mode:
|
| 58 |
+
self.max_samples = {
|
| 59 |
+
'train': 100, # Minimal for quick validation
|
| 60 |
+
'validation': 20,
|
| 61 |
+
'test': 10
|
| 62 |
+
}
|
| 63 |
+
logger.info("🧪 Minimal validation mode: 130 total samples")
|
| 64 |
+
else:
|
| 65 |
+
self.max_samples = {
|
| 66 |
+
'train': 10000, # Reasonable training set
|
| 67 |
+
'validation': 1000,
|
| 68 |
+
'test': 500
|
| 69 |
+
}
|
| 70 |
+
logger.info("📊 Full training mode: 11,500 total samples")
|
| 71 |
+
|
| 72 |
+
def extract_corpus(self):
|
| 73 |
+
"""Extract Common Voice corpus if needed"""
|
| 74 |
+
if self.cv_extracted_path.exists() and self.cv_extracted_path.is_dir():
|
| 75 |
+
logger.info(f"✅ Corpus já disponível em: {self.cv_extracted_path}")
|
| 76 |
+
return
|
| 77 |
+
|
| 78 |
+
if not self.corpus_path.exists():
|
| 79 |
+
raise FileNotFoundError(f"Corpus not found: {self.corpus_path}")
|
| 80 |
+
|
| 81 |
+
# Se é um arquivo tar.gz, extrair
|
| 82 |
+
if self.corpus_path.is_file() and str(self.corpus_path).endswith('.tar.gz'):
|
| 83 |
+
logger.info(f"📦 Extracting corpus: {self.corpus_path}")
|
| 84 |
+
|
| 85 |
+
with tarfile.open(self.corpus_path, 'r:gz') as tar:
|
| 86 |
+
# Extract to parent directory so we get cv-corpus-22-pt folder
|
| 87 |
+
tar.extractall(path=self.output_dir)
|
| 88 |
+
|
| 89 |
+
logger.info(f"✅ Corpus extracted to {self.cv_extracted_path}")
|
| 90 |
+
else:
|
| 91 |
+
logger.info(f"✅ Using pre-extracted corpus at: {self.cv_extracted_path}")
|
| 92 |
+
|
| 93 |
+
def load_metadata(self) -> pd.DataFrame:
|
| 94 |
+
"""Load and process Common Voice metadata"""
|
| 95 |
+
tsv_files = {
|
| 96 |
+
'train': self.cv_extracted_path / 'train.tsv',
|
| 97 |
+
'dev': self.cv_extracted_path / 'dev.tsv',
|
| 98 |
+
'test': self.cv_extracted_path / 'test.tsv'
|
| 99 |
+
}
|
| 100 |
+
|
| 101 |
+
all_data = []
|
| 102 |
+
|
| 103 |
+
for split, tsv_path in tsv_files.items():
|
| 104 |
+
if not tsv_path.exists():
|
| 105 |
+
logger.warning(f"⚠️ {tsv_path} not found, skipping")
|
| 106 |
+
continue
|
| 107 |
+
|
| 108 |
+
df = pd.read_csv(tsv_path, sep='\t')
|
| 109 |
+
df['split'] = 'validation' if split == 'dev' else split
|
| 110 |
+
all_data.append(df)
|
| 111 |
+
|
| 112 |
+
logger.info(f"📊 {split}: {len(df)} samples")
|
| 113 |
+
|
| 114 |
+
if not all_data:
|
| 115 |
+
raise FileNotFoundError("No TSV files found in corpus")
|
| 116 |
+
|
| 117 |
+
combined_df = pd.concat(all_data, ignore_index=True)
|
| 118 |
+
|
| 119 |
+
# Filter out samples without audio or text
|
| 120 |
+
combined_df = combined_df.dropna(subset=['path', 'sentence'])
|
| 121 |
+
|
| 122 |
+
logger.info(f"📊 Total samples: {len(combined_df)}")
|
| 123 |
+
return combined_df
|
| 124 |
+
|
| 125 |
+
def create_instruction_samples(self, df: pd.DataFrame) -> List[Dict]:
|
| 126 |
+
"""Convert Common Voice samples to simple transcription format"""
|
| 127 |
+
# SIMPLIFICADO: Apenas transcrição básica para validação inicial
|
| 128 |
+
instruction_templates = [
|
| 129 |
+
"Repita o que eu disse.",
|
| 130 |
+
"O que você ouviu?",
|
| 131 |
+
"Transcreva o que foi falado."
|
| 132 |
+
]
|
| 133 |
+
|
| 134 |
+
samples = []
|
| 135 |
+
|
| 136 |
+
for _, row in df.iterrows():
|
| 137 |
+
# Audio file path (relative to clips directory)
|
| 138 |
+
audio_path = self.cv_extracted_path / 'clips' / row['path']
|
| 139 |
+
|
| 140 |
+
# Skip if audio doesn't exist
|
| 141 |
+
if not audio_path.exists():
|
| 142 |
+
continue
|
| 143 |
+
|
| 144 |
+
# Create instruction sample
|
| 145 |
+
instruction = random.choice(instruction_templates)
|
| 146 |
+
response = row['sentence'].strip()
|
| 147 |
+
|
| 148 |
+
sample = {
|
| 149 |
+
'audio_path': str(audio_path),
|
| 150 |
+
'instruction': instruction,
|
| 151 |
+
'response': response,
|
| 152 |
+
'split': row['split'],
|
| 153 |
+
'duration': row.get('duration', 0), # Duration in seconds
|
| 154 |
+
'up_votes': row.get('up_votes', 0),
|
| 155 |
+
'down_votes': row.get('down_votes', 0)
|
| 156 |
+
}
|
| 157 |
+
|
| 158 |
+
samples.append(sample)
|
| 159 |
+
|
| 160 |
+
return samples
|
| 161 |
+
|
| 162 |
+
def filter_and_sample(self, samples: List[Dict]) -> Dict[str, List[Dict]]:
|
| 163 |
+
"""Filter samples and create splits with size limits"""
|
| 164 |
+
# Filter by quality (more up_votes than down_votes)
|
| 165 |
+
quality_samples = [
|
| 166 |
+
s for s in samples
|
| 167 |
+
if s['up_votes'] >= s['down_votes'] and s['duration'] > 0
|
| 168 |
+
]
|
| 169 |
+
|
| 170 |
+
logger.info(f"📊 Quality filtered: {len(quality_samples)} samples")
|
| 171 |
+
|
| 172 |
+
# Group by split
|
| 173 |
+
split_samples = {
|
| 174 |
+
'train': [],
|
| 175 |
+
'validation': [],
|
| 176 |
+
'test': []
|
| 177 |
+
}
|
| 178 |
+
|
| 179 |
+
for sample in quality_samples:
|
| 180 |
+
split = sample['split']
|
| 181 |
+
if split in split_samples:
|
| 182 |
+
split_samples[split].append(sample)
|
| 183 |
+
|
| 184 |
+
# Sample according to limits
|
| 185 |
+
for split, samples_list in split_samples.items():
|
| 186 |
+
max_samples = self.max_samples.get(split, len(samples_list))
|
| 187 |
+
|
| 188 |
+
if len(samples_list) > max_samples:
|
| 189 |
+
# Randomly sample
|
| 190 |
+
samples_list = random.sample(samples_list, max_samples)
|
| 191 |
+
split_samples[split] = samples_list
|
| 192 |
+
|
| 193 |
+
logger.info(f"📊 {split}: {len(samples_list)} samples (limit: {max_samples})")
|
| 194 |
+
|
| 195 |
+
return split_samples
|
| 196 |
+
|
| 197 |
+
def copy_audio_files(self, split_samples: Dict[str, List[Dict]]) -> Dict[str, List[Dict]]:
|
| 198 |
+
"""Copy audio files to processed directory and update paths"""
|
| 199 |
+
logger.info("📂 Copying audio files...")
|
| 200 |
+
|
| 201 |
+
all_samples = []
|
| 202 |
+
for samples_list in split_samples.values():
|
| 203 |
+
all_samples.extend(samples_list)
|
| 204 |
+
|
| 205 |
+
for sample in tqdm(all_samples, desc="Copying audio"):
|
| 206 |
+
old_path = Path(sample['audio_path'])
|
| 207 |
+
new_path = self.audio_dir / old_path.name
|
| 208 |
+
|
| 209 |
+
# Copy audio file if not exists
|
| 210 |
+
if not new_path.exists():
|
| 211 |
+
try:
|
| 212 |
+
# Load and save audio (also validates format)
|
| 213 |
+
audio, sr = sf.read(str(old_path))
|
| 214 |
+
sf.write(str(new_path), audio, sr)
|
| 215 |
+
except Exception as e:
|
| 216 |
+
logger.warning(f"⚠️ Failed to copy {old_path.name}: {e}")
|
| 217 |
+
continue
|
| 218 |
+
|
| 219 |
+
# Update path in sample
|
| 220 |
+
sample['audio_path'] = str(new_path)
|
| 221 |
+
|
| 222 |
+
return split_samples
|
| 223 |
+
|
| 224 |
+
def save_processed_data(self, split_samples: Dict[str, List[Dict]]):
|
| 225 |
+
"""Save processed samples to JSON files"""
|
| 226 |
+
logger.info("💾 Saving processed data...")
|
| 227 |
+
|
| 228 |
+
for split, samples_list in split_samples.items():
|
| 229 |
+
output_file = self.processed_path / f"{split}_samples.json"
|
| 230 |
+
|
| 231 |
+
with open(output_file, 'w', encoding='utf-8') as f:
|
| 232 |
+
json.dump(samples_list, f, ensure_ascii=False, indent=2)
|
| 233 |
+
|
| 234 |
+
logger.info(f"✅ {split}: {len(samples_list)} samples → {output_file}")
|
| 235 |
+
|
| 236 |
+
# Create summary
|
| 237 |
+
summary = {
|
| 238 |
+
'total_samples': sum(len(samples) for samples in split_samples.values()),
|
| 239 |
+
'splits': {split: len(samples) for split, samples in split_samples.items()},
|
| 240 |
+
'audio_dir': str(self.audio_dir),
|
| 241 |
+
'minimal_mode': self.minimal_mode,
|
| 242 |
+
'instruction_templates_count': 8
|
| 243 |
+
}
|
| 244 |
+
|
| 245 |
+
summary_file = self.processed_path / "dataset_summary.json"
|
| 246 |
+
with open(summary_file, 'w') as f:
|
| 247 |
+
json.dump(summary, f, indent=2)
|
| 248 |
+
|
| 249 |
+
logger.info(f"📊 Summary saved: {summary_file}")
|
| 250 |
+
|
| 251 |
+
def create_sample_test(self):
|
| 252 |
+
"""Create a simple test sample for immediate validation"""
|
| 253 |
+
test_sample = {
|
| 254 |
+
'audio_path': 'dummy_audio.wav',
|
| 255 |
+
'instruction': 'Qual foi a frase que eu disse?',
|
| 256 |
+
'response': 'Esta é uma frase de teste.',
|
| 257 |
+
'split': 'test'
|
| 258 |
+
}
|
| 259 |
+
|
| 260 |
+
# Create dummy audio file (1 second of silence)
|
| 261 |
+
dummy_audio = np.zeros(16000) # 1 second at 16kHz
|
| 262 |
+
dummy_path = self.processed_path / 'dummy_audio.wav'
|
| 263 |
+
sf.write(str(dummy_path), dummy_audio, 16000)
|
| 264 |
+
|
| 265 |
+
test_sample['audio_path'] = str(dummy_path)
|
| 266 |
+
|
| 267 |
+
# Save test sample
|
| 268 |
+
test_file = self.processed_path / 'quick_test.json'
|
| 269 |
+
with open(test_file, 'w', encoding='utf-8') as f:
|
| 270 |
+
json.dump([test_sample], f, ensure_ascii=False, indent=2)
|
| 271 |
+
|
| 272 |
+
logger.info(f"🧪 Quick test sample: {test_file}")
|
| 273 |
+
return test_file
|
| 274 |
+
|
| 275 |
+
def process(self):
|
| 276 |
+
"""Main processing pipeline"""
|
| 277 |
+
logger.info("🚀 Starting Common Voice 22 processing...")
|
| 278 |
+
|
| 279 |
+
# Step 1: Extract corpus
|
| 280 |
+
self.extract_corpus()
|
| 281 |
+
|
| 282 |
+
# Step 2: Load metadata
|
| 283 |
+
df = self.load_metadata()
|
| 284 |
+
|
| 285 |
+
# Step 3: Create instruction samples
|
| 286 |
+
logger.info("🎯 Creating instruction-following samples...")
|
| 287 |
+
samples = self.create_instruction_samples(df)
|
| 288 |
+
|
| 289 |
+
# Step 4: Filter and sample
|
| 290 |
+
split_samples = self.filter_and_sample(samples)
|
| 291 |
+
|
| 292 |
+
# Step 5: Copy audio files
|
| 293 |
+
split_samples = self.copy_audio_files(split_samples)
|
| 294 |
+
|
| 295 |
+
# Step 6: Save processed data
|
| 296 |
+
self.save_processed_data(split_samples)
|
| 297 |
+
|
| 298 |
+
# Step 7: Create quick test sample
|
| 299 |
+
quick_test = self.create_sample_test()
|
| 300 |
+
|
| 301 |
+
logger.info("✅ Common Voice 22 processing completed!")
|
| 302 |
+
|
| 303 |
+
return {
|
| 304 |
+
'processed_path': self.processed_path,
|
| 305 |
+
'splits': {split: len(samples) for split, samples in split_samples.items()},
|
| 306 |
+
'quick_test': quick_test
|
| 307 |
+
}
|
| 308 |
+
|
| 309 |
+
|
| 310 |
+
def main():
|
| 311 |
+
parser = argparse.ArgumentParser(description="Process Common Voice 22 Portuguese")
|
| 312 |
+
parser.add_argument(
|
| 313 |
+
'--corpus-path',
|
| 314 |
+
type=str,
|
| 315 |
+
default='/workspace/llama-omni2-compact/training/cv-corpus-22.0-2025-06-20-pt/cv-corpus-22.0-2025-06-20/pt',
|
| 316 |
+
help='Path to Common Voice corpus directory (or tar.gz file)'
|
| 317 |
+
)
|
| 318 |
+
parser.add_argument(
|
| 319 |
+
'--output-dir',
|
| 320 |
+
type=str,
|
| 321 |
+
default='/workspace/llama-omni2-compact/training/qwen3-0.6b/data',
|
| 322 |
+
help='Output directory for processed data'
|
| 323 |
+
)
|
| 324 |
+
parser.add_argument(
|
| 325 |
+
'--minimal',
|
| 326 |
+
action='store_true',
|
| 327 |
+
help='Minimal mode for quick validation (130 samples)'
|
| 328 |
+
)
|
| 329 |
+
|
| 330 |
+
args = parser.parse_args()
|
| 331 |
+
|
| 332 |
+
# Setup logging
|
| 333 |
+
logging.basicConfig(
|
| 334 |
+
level=logging.INFO,
|
| 335 |
+
format='%(asctime)s - %(levelname)s - %(message)s'
|
| 336 |
+
)
|
| 337 |
+
|
| 338 |
+
# Process dataset
|
| 339 |
+
processor = CommonVoice22Processor(
|
| 340 |
+
corpus_path=args.corpus_path,
|
| 341 |
+
output_dir=args.output_dir,
|
| 342 |
+
minimal_mode=args.minimal
|
| 343 |
+
)
|
| 344 |
+
|
| 345 |
+
try:
|
| 346 |
+
results = processor.process()
|
| 347 |
+
|
| 348 |
+
print("\n" + "="*60)
|
| 349 |
+
print("📊 PROCESSING COMPLETED")
|
| 350 |
+
print("="*60)
|
| 351 |
+
print(f"📁 Data directory: {results['processed_path']}")
|
| 352 |
+
print(f"🧪 Quick test: {results['quick_test']}")
|
| 353 |
+
print("\nSplit distribution:")
|
| 354 |
+
for split, count in results['splits'].items():
|
| 355 |
+
print(f" • {split}: {count} samples")
|
| 356 |
+
print("="*60)
|
| 357 |
+
|
| 358 |
+
except Exception as e:
|
| 359 |
+
logger.error(f"❌ Processing failed: {e}")
|
| 360 |
+
sys.exit(1)
|
| 361 |
+
|
| 362 |
+
|
| 363 |
+
if __name__ == "__main__":
|
| 364 |
+
main()
|
training/qwen3-0.6b/data/synthetic_samples.py
ADDED
|
@@ -0,0 +1,288 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""
|
| 3 |
+
Gerador de Samples Sintéticos para Treinamento
|
| 4 |
+
==============================================
|
| 5 |
+
Cria samples sintéticos de instrução-resposta em português
|
| 6 |
+
para complementar o dataset do Common Voice
|
| 7 |
+
"""
|
| 8 |
+
|
| 9 |
+
import json
|
| 10 |
+
import random
|
| 11 |
+
from typing import List, Dict
|
| 12 |
+
import logging
|
| 13 |
+
|
| 14 |
+
logger = logging.getLogger(__name__)
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
class PortugueseInstructionGenerator:
|
| 18 |
+
"""
|
| 19 |
+
Gera instruções sintéticas em português para treinamento de embeddings de fala
|
| 20 |
+
|
| 21 |
+
Categorias:
|
| 22 |
+
- Perguntas factuais
|
| 23 |
+
- Pedidos de repetição/transcrição
|
| 24 |
+
- Comandos simples
|
| 25 |
+
- Perguntas sobre conhecimento geral
|
| 26 |
+
"""
|
| 27 |
+
|
| 28 |
+
def __init__(self):
|
| 29 |
+
# Templates de instrução por categoria
|
| 30 |
+
self.instruction_templates = {
|
| 31 |
+
'transcription': [
|
| 32 |
+
"Qual foi a frase que eu disse?",
|
| 33 |
+
"O que você ouviu?",
|
| 34 |
+
"Transcreva o que foi falado.",
|
| 35 |
+
"Repita o que eu disse.",
|
| 36 |
+
"Qual é o conteúdo desta gravação?",
|
| 37 |
+
"O que está sendo dito no áudio?",
|
| 38 |
+
"Identifique a frase falada.",
|
| 39 |
+
"Qual a transcrição deste áudio?",
|
| 40 |
+
"Me diga o que você escutou.",
|
| 41 |
+
"Reproduza a frase que falei."
|
| 42 |
+
],
|
| 43 |
+
|
| 44 |
+
'questions': [
|
| 45 |
+
"Responda a pergunta que fiz.",
|
| 46 |
+
"Qual é a resposta para minha pergunta?",
|
| 47 |
+
"Me ajude com esta questão.",
|
| 48 |
+
"Você pode responder isso?",
|
| 49 |
+
"O que você acha sobre o que perguntei?",
|
| 50 |
+
"Forneça uma resposta para minha dúvida.",
|
| 51 |
+
"Explique a resposta desta pergunta.",
|
| 52 |
+
"Como você responderia a isso?"
|
| 53 |
+
],
|
| 54 |
+
|
| 55 |
+
'general': [
|
| 56 |
+
"Processe este áudio e me responda.",
|
| 57 |
+
"Analise o que eu disse.",
|
| 58 |
+
"Interprete minha mensagem de voz.",
|
| 59 |
+
"Compreenda e responda ao áudio.",
|
| 60 |
+
"O que posso fazer com relação ao que falei?",
|
| 61 |
+
"Ajude-me baseado no que disse.",
|
| 62 |
+
"Forneça uma resposta apropriada.",
|
| 63 |
+
"Como você interpretaria isso?"
|
| 64 |
+
]
|
| 65 |
+
}
|
| 66 |
+
|
| 67 |
+
# Frases exemplo em português (Common Voice style)
|
| 68 |
+
self.sample_sentences = [
|
| 69 |
+
"Hoje está um dia muito bonito.",
|
| 70 |
+
"Gosto de escutar música clássica.",
|
| 71 |
+
"O Brasil é um país muito diverso.",
|
| 72 |
+
"A tecnologia avança rapidamente.",
|
| 73 |
+
"Preciso comprar pão na padaria.",
|
| 74 |
+
"Meus amigos chegaram cedo.",
|
| 75 |
+
"O filme foi muito interessante.",
|
| 76 |
+
"Vou viajar nas férias de verão.",
|
| 77 |
+
"A chuva começou a cair forte.",
|
| 78 |
+
"Estou aprendendo uma nova língua.",
|
| 79 |
+
"O gato subiu no telhado.",
|
| 80 |
+
"Adoro cozinhar comida italiana.",
|
| 81 |
+
"O trânsito estava muito intenso.",
|
| 82 |
+
"Encontrei um livro fascinante.",
|
| 83 |
+
"A reunião terminou mais cedo."
|
| 84 |
+
]
|
| 85 |
+
|
| 86 |
+
# Perguntas factuais
|
| 87 |
+
self.factual_qa = [
|
| 88 |
+
{
|
| 89 |
+
"question": "Qual é a capital do Brasil?",
|
| 90 |
+
"answer": "A capital do Brasil é Brasília."
|
| 91 |
+
},
|
| 92 |
+
{
|
| 93 |
+
"question": "Quantos estados tem o Brasil?",
|
| 94 |
+
"answer": "O Brasil tem 26 estados e 1 distrito federal."
|
| 95 |
+
},
|
| 96 |
+
{
|
| 97 |
+
"question": "Qual é o maior país da América do Sul?",
|
| 98 |
+
"answer": "O Brasil é o maior país da América do Sul."
|
| 99 |
+
},
|
| 100 |
+
{
|
| 101 |
+
"question": "Em que continente fica o Brasil?",
|
| 102 |
+
"answer": "O Brasil fica na América do Sul."
|
| 103 |
+
},
|
| 104 |
+
{
|
| 105 |
+
"question": "Qual é a moeda do Brasil?",
|
| 106 |
+
"answer": "A moeda do Brasil é o Real."
|
| 107 |
+
},
|
| 108 |
+
{
|
| 109 |
+
"question": "Quantos dias tem uma semana?",
|
| 110 |
+
"answer": "Uma semana tem sete dias."
|
| 111 |
+
},
|
| 112 |
+
{
|
| 113 |
+
"question": "Qual é o maior oceano do mundo?",
|
| 114 |
+
"answer": "O maior oceano do mundo é o Pacífico."
|
| 115 |
+
},
|
| 116 |
+
{
|
| 117 |
+
"question": "Quantas horas tem um dia?",
|
| 118 |
+
"answer": "Um dia tem vinte e quatro horas."
|
| 119 |
+
}
|
| 120 |
+
]
|
| 121 |
+
|
| 122 |
+
def generate_transcription_samples(self, count: int = 50) -> List[Dict]:
|
| 123 |
+
"""Gera samples de transcrição"""
|
| 124 |
+
samples = []
|
| 125 |
+
|
| 126 |
+
for _ in range(count):
|
| 127 |
+
sentence = random.choice(self.sample_sentences)
|
| 128 |
+
instruction = random.choice(self.instruction_templates['transcription'])
|
| 129 |
+
|
| 130 |
+
sample = {
|
| 131 |
+
'type': 'transcription',
|
| 132 |
+
'instruction': instruction,
|
| 133 |
+
'audio_content': sentence, # O que seria falado no áudio
|
| 134 |
+
'response': sentence, # Resposta esperada
|
| 135 |
+
'category': 'synthetic'
|
| 136 |
+
}
|
| 137 |
+
|
| 138 |
+
samples.append(sample)
|
| 139 |
+
|
| 140 |
+
return samples
|
| 141 |
+
|
| 142 |
+
def generate_qa_samples(self, count: int = 30) -> List[Dict]:
|
| 143 |
+
"""Gera samples de perguntas e respostas"""
|
| 144 |
+
samples = []
|
| 145 |
+
|
| 146 |
+
for _ in range(count):
|
| 147 |
+
qa = random.choice(self.factual_qa)
|
| 148 |
+
instruction = random.choice(self.instruction_templates['questions'])
|
| 149 |
+
|
| 150 |
+
sample = {
|
| 151 |
+
'type': 'qa',
|
| 152 |
+
'instruction': instruction,
|
| 153 |
+
'audio_content': qa['question'], # Pergunta falada
|
| 154 |
+
'response': qa['answer'], # Resposta esperada
|
| 155 |
+
'category': 'synthetic'
|
| 156 |
+
}
|
| 157 |
+
|
| 158 |
+
samples.append(sample)
|
| 159 |
+
|
| 160 |
+
return samples
|
| 161 |
+
|
| 162 |
+
def generate_general_samples(self, count: int = 20) -> List[Dict]:
|
| 163 |
+
"""Gera samples gerais"""
|
| 164 |
+
samples = []
|
| 165 |
+
|
| 166 |
+
general_pairs = [
|
| 167 |
+
{
|
| 168 |
+
'content': 'Quero saber as horas.',
|
| 169 |
+
'response': 'Para saber as horas, você pode olhar no relógio ou perguntar a alguém.'
|
| 170 |
+
},
|
| 171 |
+
{
|
| 172 |
+
'content': 'Como está o tempo hoje?',
|
| 173 |
+
'response': 'Para saber como está o tempo, você pode olhar pela janela ou verificar a previsão meteorológica.'
|
| 174 |
+
},
|
| 175 |
+
{
|
| 176 |
+
'content': 'Preciso de ajuda com uma tarefa.',
|
| 177 |
+
'response': 'Ficarei feliz em ajudar. Pode me explicar qual é a tarefa?'
|
| 178 |
+
},
|
| 179 |
+
{
|
| 180 |
+
'content': 'Estou com fome.',
|
| 181 |
+
'response': 'Que tal preparar algo para comer ou pedir uma refeição?'
|
| 182 |
+
},
|
| 183 |
+
{
|
| 184 |
+
'content': 'Não entendi a explicação.',
|
| 185 |
+
'response': 'Posso tentar explicar de forma mais clara. Qual parte você gostaria que eu esclarecesse?'
|
| 186 |
+
}
|
| 187 |
+
]
|
| 188 |
+
|
| 189 |
+
for _ in range(count):
|
| 190 |
+
pair = random.choice(general_pairs)
|
| 191 |
+
instruction = random.choice(self.instruction_templates['general'])
|
| 192 |
+
|
| 193 |
+
sample = {
|
| 194 |
+
'type': 'general',
|
| 195 |
+
'instruction': instruction,
|
| 196 |
+
'audio_content': pair['content'],
|
| 197 |
+
'response': pair['response'],
|
| 198 |
+
'category': 'synthetic'
|
| 199 |
+
}
|
| 200 |
+
|
| 201 |
+
samples.append(sample)
|
| 202 |
+
|
| 203 |
+
return samples
|
| 204 |
+
|
| 205 |
+
def generate_complete_dataset(self,
|
| 206 |
+
transcription_count: int = 50,
|
| 207 |
+
qa_count: int = 30,
|
| 208 |
+
general_count: int = 20) -> List[Dict]:
|
| 209 |
+
"""Gera dataset completo com todos os tipos"""
|
| 210 |
+
logger.info("🎯 Gerando samples sintéticos...")
|
| 211 |
+
|
| 212 |
+
samples = []
|
| 213 |
+
|
| 214 |
+
# Gerar cada tipo
|
| 215 |
+
samples.extend(self.generate_transcription_samples(transcription_count))
|
| 216 |
+
samples.extend(self.generate_qa_samples(qa_count))
|
| 217 |
+
samples.extend(self.generate_general_samples(general_count))
|
| 218 |
+
|
| 219 |
+
# Embaralhar
|
| 220 |
+
random.shuffle(samples)
|
| 221 |
+
|
| 222 |
+
logger.info(f"✅ {len(samples)} samples sintéticos gerados")
|
| 223 |
+
logger.info(f" • Transcrição: {transcription_count}")
|
| 224 |
+
logger.info(f" • Q&A: {qa_count}")
|
| 225 |
+
logger.info(f" • Geral: {general_count}")
|
| 226 |
+
|
| 227 |
+
return samples
|
| 228 |
+
|
| 229 |
+
def save_samples(self, samples: List[Dict], output_path: str):
|
| 230 |
+
"""Salva samples em arquivo JSON"""
|
| 231 |
+
with open(output_path, 'w', encoding='utf-8') as f:
|
| 232 |
+
json.dump(samples, f, ensure_ascii=False, indent=2)
|
| 233 |
+
|
| 234 |
+
logger.info(f"💾 Samples salvos em: {output_path}")
|
| 235 |
+
|
| 236 |
+
|
| 237 |
+
def main():
|
| 238 |
+
"""Gera e salva samples sintéticos"""
|
| 239 |
+
import argparse
|
| 240 |
+
from pathlib import Path
|
| 241 |
+
|
| 242 |
+
parser = argparse.ArgumentParser(description="Gerar samples sintéticos em português")
|
| 243 |
+
parser.add_argument(
|
| 244 |
+
'--output',
|
| 245 |
+
type=str,
|
| 246 |
+
default='portuguese_instructions.json',
|
| 247 |
+
help='Arquivo de saída'
|
| 248 |
+
)
|
| 249 |
+
parser.add_argument(
|
| 250 |
+
'--transcription',
|
| 251 |
+
type=int,
|
| 252 |
+
default=50,
|
| 253 |
+
help='Número de samples de transcrição'
|
| 254 |
+
)
|
| 255 |
+
parser.add_argument(
|
| 256 |
+
'--qa',
|
| 257 |
+
type=int,
|
| 258 |
+
default=30,
|
| 259 |
+
help='Número de samples de Q&A'
|
| 260 |
+
)
|
| 261 |
+
parser.add_argument(
|
| 262 |
+
'--general',
|
| 263 |
+
type=int,
|
| 264 |
+
default=20,
|
| 265 |
+
help='Número de samples gerais'
|
| 266 |
+
)
|
| 267 |
+
|
| 268 |
+
args = parser.parse_args()
|
| 269 |
+
|
| 270 |
+
# Setup logging
|
| 271 |
+
logging.basicConfig(level=logging.INFO)
|
| 272 |
+
|
| 273 |
+
# Gerar samples
|
| 274 |
+
generator = PortugueseInstructionGenerator()
|
| 275 |
+
samples = generator.generate_complete_dataset(
|
| 276 |
+
transcription_count=args.transcription,
|
| 277 |
+
qa_count=args.qa,
|
| 278 |
+
general_count=args.general
|
| 279 |
+
)
|
| 280 |
+
|
| 281 |
+
# Salvar
|
| 282 |
+
generator.save_samples(samples, args.output)
|
| 283 |
+
|
| 284 |
+
print(f"\n✅ {len(samples)} samples sintéticos criados em {args.output}")
|
| 285 |
+
|
| 286 |
+
|
| 287 |
+
if __name__ == "__main__":
|
| 288 |
+
main()
|
training/qwen3-0.6b/requirements.txt
ADDED
|
@@ -0,0 +1,94 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# 🎤 Qwen3-0.6B Speech Embeddings Training Requirements
|
| 2 |
+
# Based on LLaMA-Omni2 + LoRA-Whisper methodologies
|
| 3 |
+
|
| 4 |
+
# Core dependencies
|
| 5 |
+
torch>=2.0.0
|
| 6 |
+
torchaudio>=2.0.0
|
| 7 |
+
transformers>=4.51.0 # Latest for Qwen3 support
|
| 8 |
+
tokenizers>=0.21.0
|
| 9 |
+
|
| 10 |
+
# Speech processing
|
| 11 |
+
openai-whisper>=20231117 # Whisper large-v3
|
| 12 |
+
soundfile>=0.12.0
|
| 13 |
+
librosa>=0.10.0
|
| 14 |
+
scipy>=1.11.0
|
| 15 |
+
|
| 16 |
+
# LoRA and parameter-efficient fine-tuning
|
| 17 |
+
peft>=0.10.0 # Parameter-Efficient Fine-Tuning
|
| 18 |
+
bitsandbytes>=0.42.0 # Quantization support
|
| 19 |
+
|
| 20 |
+
# Dataset processing
|
| 21 |
+
datasets>=2.16.0 # HuggingFace datasets
|
| 22 |
+
accelerate>=0.25.0 # Distributed training
|
| 23 |
+
evaluate>=0.4.0 # Evaluation metrics
|
| 24 |
+
|
| 25 |
+
# Model utilities
|
| 26 |
+
safetensors>=0.4.0 # Model serialization
|
| 27 |
+
huggingface-hub>=0.20.0 # Model hub integration
|
| 28 |
+
|
| 29 |
+
# Training utilities
|
| 30 |
+
tqdm>=4.66.0 # Progress bars
|
| 31 |
+
wandb>=0.16.0 # Experiment tracking (optional)
|
| 32 |
+
tensorboard>=2.15.0 # TensorBoard logging (optional)
|
| 33 |
+
|
| 34 |
+
# Data processing
|
| 35 |
+
pandas>=2.1.0
|
| 36 |
+
numpy>=1.24.0
|
| 37 |
+
PyYAML>=6.0
|
| 38 |
+
|
| 39 |
+
# Audio augmentation (optional)
|
| 40 |
+
audiomentations>=0.35.0 # Audio data augmentation
|
| 41 |
+
pyroomacoustics>=0.7.0 # Room acoustics simulation
|
| 42 |
+
|
| 43 |
+
# Portuguese NLP (for instruction processing)
|
| 44 |
+
nltk>=3.8
|
| 45 |
+
spacy>=3.7.0
|
| 46 |
+
# python -m spacy download pt_core_news_sm # Portuguese model
|
| 47 |
+
|
| 48 |
+
# Evaluation metrics
|
| 49 |
+
sacrebleu>=2.3.0 # BLEU score calculation
|
| 50 |
+
rouge-score>=0.1.0 # ROUGE metrics
|
| 51 |
+
sentence-transformers>=2.2.0 # Semantic similarity
|
| 52 |
+
|
| 53 |
+
# Utilities
|
| 54 |
+
colorlog>=6.8.0 # Colored logging
|
| 55 |
+
psutil>=5.9.0 # System monitoring
|
| 56 |
+
GPUtil>=1.4.0 # GPU monitoring
|
| 57 |
+
|
| 58 |
+
# Optional dependencies for advanced features
|
| 59 |
+
# Uncomment if needed:
|
| 60 |
+
|
| 61 |
+
# Speech synthesis (for Stage II)
|
| 62 |
+
# TTS>=0.22.0 # Coqui TTS
|
| 63 |
+
# espeak-ng # Text-to-speech backend
|
| 64 |
+
|
| 65 |
+
# Advanced audio processing
|
| 66 |
+
# pyannote.audio>=3.1.0 # Speaker diarization
|
| 67 |
+
# speechbrain>=0.5.0 # Speech processing toolkit
|
| 68 |
+
|
| 69 |
+
# Distributed training
|
| 70 |
+
# deepspeed>=0.12.0 # DeepSpeed optimization
|
| 71 |
+
# fairscale>=0.4.0 # Facebook's scaling library
|
| 72 |
+
|
| 73 |
+
# Development tools
|
| 74 |
+
pytest>=7.4.0 # Testing
|
| 75 |
+
black>=23.0.0 # Code formatting
|
| 76 |
+
flake8>=6.0.0 # Linting
|
| 77 |
+
|
| 78 |
+
# Platform-specific installations:
|
| 79 |
+
#
|
| 80 |
+
# For CUDA 11.8:
|
| 81 |
+
# pip install torch torchaudio --index-url https://download.pytorch.org/whl/cu118
|
| 82 |
+
#
|
| 83 |
+
# For CUDA 12.1:
|
| 84 |
+
# pip install torch torchaudio --index-url https://download.pytorch.org/whl/cu121
|
| 85 |
+
#
|
| 86 |
+
# For CPU only:
|
| 87 |
+
# pip install torch torchaudio --index-url https://download.pytorch.org/whl/cpu
|
| 88 |
+
|
| 89 |
+
# Installation notes:
|
| 90 |
+
# 1. Install PyTorch first with correct CUDA version
|
| 91 |
+
# 2. Install Whisper: pip install -U openai-whisper
|
| 92 |
+
# 3. Download Portuguese spaCy model: python -m spacy download pt_core_news_sm
|
| 93 |
+
# 4. For Common Voice dataset: pip install datasets[audio]
|
| 94 |
+
# 5. Optional: Install ffmpeg for audio processing: apt-get install ffmpeg
|
training/qwen3-0.6b/scripts/quick_validation.py
ADDED
|
@@ -0,0 +1,424 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""
|
| 3 |
+
Quick Validation Script
|
| 4 |
+
=======================
|
| 5 |
+
Minimal training setup for rapid validation of the speech embeddings pipeline
|
| 6 |
+
Tests if the basic architecture works before full training
|
| 7 |
+
"""
|
| 8 |
+
|
| 9 |
+
import os
|
| 10 |
+
import sys
|
| 11 |
+
import yaml
|
| 12 |
+
import torch
|
| 13 |
+
import torch.nn as nn
|
| 14 |
+
from pathlib import Path
|
| 15 |
+
import logging
|
| 16 |
+
import json
|
| 17 |
+
import numpy as np
|
| 18 |
+
from tqdm import tqdm
|
| 19 |
+
import time
|
| 20 |
+
import whisper
|
| 21 |
+
|
| 22 |
+
# Add project root to path
|
| 23 |
+
sys.path.append(str(Path(__file__).parent.parent))
|
| 24 |
+
|
| 25 |
+
from models.speech_adapter import SpeechAdapterModule
|
| 26 |
+
from models.lora_qwen3 import LoRAQwen3ForSpeech
|
| 27 |
+
|
| 28 |
+
logger = logging.getLogger(__name__)
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
class QuickValidator:
|
| 32 |
+
"""
|
| 33 |
+
Quick validation of speech embeddings pipeline
|
| 34 |
+
|
| 35 |
+
Tests:
|
| 36 |
+
1. Model loading (Whisper + Speech Adapter + LoRA Qwen3)
|
| 37 |
+
2. Forward pass with dummy data
|
| 38 |
+
3. Training step with minimal data
|
| 39 |
+
4. Inference with speech input
|
| 40 |
+
5. Basic functionality verification
|
| 41 |
+
"""
|
| 42 |
+
|
| 43 |
+
def __init__(self, config_path: str = None):
|
| 44 |
+
# Setup logging
|
| 45 |
+
logging.basicConfig(
|
| 46 |
+
level=logging.INFO,
|
| 47 |
+
format='%(asctime)s - %(levelname)s - %(message)s'
|
| 48 |
+
)
|
| 49 |
+
|
| 50 |
+
# Load configuration
|
| 51 |
+
if config_path is None:
|
| 52 |
+
config_path = Path(__file__).parent.parent / "config" / "training_config.yaml"
|
| 53 |
+
|
| 54 |
+
with open(config_path, 'r') as f:
|
| 55 |
+
self.config = yaml.safe_load(f)
|
| 56 |
+
|
| 57 |
+
# Override for quick validation
|
| 58 |
+
self.config["debug"]["enabled"] = True
|
| 59 |
+
self.config["stage1"]["epochs"] = 1
|
| 60 |
+
self.config["stage1"]["batch_size"] = 2
|
| 61 |
+
|
| 62 |
+
self.device = self.config["model"]["device"]
|
| 63 |
+
if not torch.cuda.is_available():
|
| 64 |
+
self.device = "cpu"
|
| 65 |
+
logger.warning("⚠️ CUDA not available, using CPU")
|
| 66 |
+
|
| 67 |
+
# Initialize components
|
| 68 |
+
self.whisper_model = None
|
| 69 |
+
self.speech_adapter = None
|
| 70 |
+
self.lora_qwen3 = None
|
| 71 |
+
|
| 72 |
+
logger.info("🧪 Quick Validator initialized")
|
| 73 |
+
|
| 74 |
+
def test_whisper_loading(self) -> bool:
|
| 75 |
+
"""Test 1: Load Whisper model"""
|
| 76 |
+
logger.info("📦 Test 1: Loading Whisper...")
|
| 77 |
+
|
| 78 |
+
try:
|
| 79 |
+
start_time = time.time()
|
| 80 |
+
|
| 81 |
+
# Try to load from local file first
|
| 82 |
+
whisper_path = self.config["whisper"].get("model_path")
|
| 83 |
+
if whisper_path and os.path.exists(whisper_path):
|
| 84 |
+
self.whisper_model = whisper.load_model(whisper_path, device=self.device)
|
| 85 |
+
else:
|
| 86 |
+
self.whisper_model = whisper.load_model("large-v3", device=self.device)
|
| 87 |
+
|
| 88 |
+
load_time = time.time() - start_time
|
| 89 |
+
logger.info(f"✅ Whisper loaded in {load_time:.1f}s")
|
| 90 |
+
|
| 91 |
+
# Test basic functionality
|
| 92 |
+
dummy_audio = np.random.randn(16000 * 2).astype(np.float32)
|
| 93 |
+
mel = whisper.log_mel_spectrogram(dummy_audio, n_mels=128)
|
| 94 |
+
|
| 95 |
+
with torch.no_grad():
|
| 96 |
+
features = self.whisper_model.encoder(mel.unsqueeze(0).to(self.device))
|
| 97 |
+
|
| 98 |
+
logger.info(f" • Dummy audio processed: {features.shape}")
|
| 99 |
+
return True
|
| 100 |
+
|
| 101 |
+
except Exception as e:
|
| 102 |
+
logger.error(f"❌ Whisper loading failed: {e}")
|
| 103 |
+
return False
|
| 104 |
+
|
| 105 |
+
def test_speech_adapter(self) -> bool:
|
| 106 |
+
"""Test 2: Create and test Speech Adapter"""
|
| 107 |
+
logger.info("🎤 Test 2: Speech Adapter...")
|
| 108 |
+
|
| 109 |
+
try:
|
| 110 |
+
# Create speech adapter
|
| 111 |
+
self.speech_adapter = SpeechAdapterModule(
|
| 112 |
+
whisper_model=self.whisper_model,
|
| 113 |
+
encoder_dim=self.config["speech_projector"]["input_dim"],
|
| 114 |
+
llm_dim=self.config["speech_projector"]["output_dim"],
|
| 115 |
+
k=self.config["speech_projector"]["downsample_factor"],
|
| 116 |
+
device=self.device
|
| 117 |
+
)
|
| 118 |
+
|
| 119 |
+
total, trainable = self.speech_adapter.get_parameter_count()
|
| 120 |
+
logger.info(f"✅ Speech Adapter created")
|
| 121 |
+
logger.info(f" • Total params: {total:,}")
|
| 122 |
+
logger.info(f" • Trainable params: {trainable:,}")
|
| 123 |
+
|
| 124 |
+
# Test forward pass
|
| 125 |
+
dummy_audio = np.random.randn(16000 * 3).astype(np.float32)
|
| 126 |
+
mel = whisper.log_mel_spectrogram(dummy_audio, n_mels=128).permute(1, 0)
|
| 127 |
+
|
| 128 |
+
with torch.no_grad():
|
| 129 |
+
output = self.speech_adapter(mel.unsqueeze(0).to(self.device))
|
| 130 |
+
|
| 131 |
+
logger.info(f" • Forward pass: {mel.shape} → {output.shape}")
|
| 132 |
+
return True
|
| 133 |
+
|
| 134 |
+
except Exception as e:
|
| 135 |
+
logger.error(f"❌ Speech Adapter failed: {e}")
|
| 136 |
+
return False
|
| 137 |
+
|
| 138 |
+
def test_lora_qwen3(self) -> bool:
|
| 139 |
+
"""Test 3: Load LoRA Qwen3"""
|
| 140 |
+
logger.info("🧠 Test 3: LoRA Qwen3...")
|
| 141 |
+
|
| 142 |
+
try:
|
| 143 |
+
# Create LoRA Qwen3 with small config for testing
|
| 144 |
+
test_lora_config = {
|
| 145 |
+
"r": 4, # Small rank for testing
|
| 146 |
+
"alpha": 8,
|
| 147 |
+
"dropout": 0.1,
|
| 148 |
+
"target_modules": ["q_proj", "v_proj"], # Only 2 modules for speed
|
| 149 |
+
"bias": "none",
|
| 150 |
+
"task_type": "CAUSAL_LM"
|
| 151 |
+
}
|
| 152 |
+
|
| 153 |
+
self.lora_qwen3 = LoRAQwen3ForSpeech(
|
| 154 |
+
model_name=self.config["model"]["name"],
|
| 155 |
+
lora_config=test_lora_config,
|
| 156 |
+
device=self.device,
|
| 157 |
+
torch_dtype="float32"
|
| 158 |
+
)
|
| 159 |
+
|
| 160 |
+
logger.info("✅ LoRA Qwen3 loaded")
|
| 161 |
+
|
| 162 |
+
# Test text generation
|
| 163 |
+
test_text = "What is the capital of Brazil?"
|
| 164 |
+
inputs = self.lora_qwen3.tokenizer(test_text, return_tensors="pt")
|
| 165 |
+
|
| 166 |
+
with torch.no_grad():
|
| 167 |
+
outputs = self.lora_qwen3.generate(
|
| 168 |
+
**inputs,
|
| 169 |
+
max_new_tokens=10,
|
| 170 |
+
temperature=0.7
|
| 171 |
+
)
|
| 172 |
+
|
| 173 |
+
response = self.lora_qwen3.tokenizer.decode(outputs[0], skip_special_tokens=True)
|
| 174 |
+
logger.info(f" • Text generation test: '{test_text}' → '{response}'")
|
| 175 |
+
|
| 176 |
+
return True
|
| 177 |
+
|
| 178 |
+
except Exception as e:
|
| 179 |
+
logger.error(f"❌ LoRA Qwen3 failed: {e}")
|
| 180 |
+
return False
|
| 181 |
+
|
| 182 |
+
def test_speech_integration(self) -> bool:
|
| 183 |
+
"""Test 4: Speech-to-text integration"""
|
| 184 |
+
logger.info("🔗 Test 4: Speech integration...")
|
| 185 |
+
|
| 186 |
+
try:
|
| 187 |
+
# Create dummy speech embeddings
|
| 188 |
+
batch_size, seq_len, hidden_dim = 1, 100, 1024
|
| 189 |
+
dummy_speech = torch.randn(batch_size, seq_len, hidden_dim, device=self.device)
|
| 190 |
+
|
| 191 |
+
# Create input with speech token
|
| 192 |
+
speech_text = "<speech> What did I say?"
|
| 193 |
+
inputs = self.lora_qwen3.tokenizer(speech_text, return_tensors="pt")
|
| 194 |
+
|
| 195 |
+
# Replace speech token with special index
|
| 196 |
+
speech_token_id = self.lora_qwen3.tokenizer.convert_tokens_to_ids("<speech>")
|
| 197 |
+
inputs["input_ids"][inputs["input_ids"] == speech_token_id] = self.lora_qwen3.SPEECH_TOKEN_INDEX
|
| 198 |
+
|
| 199 |
+
# Test mixed embeddings preparation
|
| 200 |
+
mixed_embeds = self.lora_qwen3.prepare_inputs_with_speech(
|
| 201 |
+
inputs["input_ids"].to(self.device),
|
| 202 |
+
dummy_speech
|
| 203 |
+
)
|
| 204 |
+
|
| 205 |
+
logger.info(f" • Mixed embeddings: {mixed_embeds.shape}")
|
| 206 |
+
|
| 207 |
+
# Test forward pass with mixed embeddings
|
| 208 |
+
with torch.no_grad():
|
| 209 |
+
outputs = self.lora_qwen3(inputs_embeds=mixed_embeds)
|
| 210 |
+
|
| 211 |
+
logger.info(f" • Forward pass successful: loss = {outputs.loss.item():.4f}")
|
| 212 |
+
return True
|
| 213 |
+
|
| 214 |
+
except Exception as e:
|
| 215 |
+
logger.error(f"❌ Speech integration failed: {e}")
|
| 216 |
+
return False
|
| 217 |
+
|
| 218 |
+
def test_end_to_end_pipeline(self) -> bool:
|
| 219 |
+
"""Test 5: Complete end-to-end pipeline"""
|
| 220 |
+
logger.info("🚀 Test 5: End-to-end pipeline...")
|
| 221 |
+
|
| 222 |
+
try:
|
| 223 |
+
# Create realistic audio
|
| 224 |
+
duration = 2 # 2 seconds
|
| 225 |
+
sample_rate = 16000
|
| 226 |
+
dummy_audio = np.random.randn(duration * sample_rate).astype(np.float32) * 0.01
|
| 227 |
+
|
| 228 |
+
# Step 1: Audio → Mel spectrogram
|
| 229 |
+
mel = whisper.log_mel_spectrogram(dummy_audio, n_mels=128).permute(1, 0)
|
| 230 |
+
|
| 231 |
+
# Step 2: Mel → Speech embeddings
|
| 232 |
+
speech_embeddings = self.speech_adapter(mel.unsqueeze(0).to(self.device))
|
| 233 |
+
|
| 234 |
+
# Step 3: Create instruction input
|
| 235 |
+
instruction = "Repita o que eu disse."
|
| 236 |
+
inputs = self.lora_qwen3.tokenizer(
|
| 237 |
+
f"<speech> {instruction}",
|
| 238 |
+
return_tensors="pt"
|
| 239 |
+
)
|
| 240 |
+
|
| 241 |
+
# Replace speech token
|
| 242 |
+
speech_token_id = self.lora_qwen3.tokenizer.convert_tokens_to_ids("<speech>")
|
| 243 |
+
inputs["input_ids"][inputs["input_ids"] == speech_token_id] = self.lora_qwen3.SPEECH_TOKEN_INDEX
|
| 244 |
+
|
| 245 |
+
# Step 4: Generate response
|
| 246 |
+
mixed_embeds = self.lora_qwen3.prepare_inputs_with_speech(
|
| 247 |
+
inputs["input_ids"].to(self.device),
|
| 248 |
+
speech_embeddings
|
| 249 |
+
)
|
| 250 |
+
|
| 251 |
+
with torch.no_grad():
|
| 252 |
+
outputs = self.lora_qwen3.generate(
|
| 253 |
+
inputs_embeds=mixed_embeds,
|
| 254 |
+
max_new_tokens=20,
|
| 255 |
+
temperature=0.7
|
| 256 |
+
)
|
| 257 |
+
|
| 258 |
+
response = self.lora_qwen3.tokenizer.decode(outputs[0], skip_special_tokens=True)
|
| 259 |
+
|
| 260 |
+
logger.info(f"✅ End-to-end pipeline successful!")
|
| 261 |
+
logger.info(f" • Input audio: {duration}s")
|
| 262 |
+
logger.info(f" • Speech embeddings: {speech_embeddings.shape}")
|
| 263 |
+
logger.info(f" • Response: '{response}'")
|
| 264 |
+
|
| 265 |
+
return True
|
| 266 |
+
|
| 267 |
+
except Exception as e:
|
| 268 |
+
logger.error(f"❌ End-to-end pipeline failed: {e}")
|
| 269 |
+
return False
|
| 270 |
+
|
| 271 |
+
def test_minimal_training_step(self) -> bool:
|
| 272 |
+
"""Test 6: Minimal training step"""
|
| 273 |
+
logger.info("📚 Test 6: Minimal training step...")
|
| 274 |
+
|
| 275 |
+
try:
|
| 276 |
+
# Create minimal training data
|
| 277 |
+
batch_size = 2
|
| 278 |
+
seq_len = 50
|
| 279 |
+
|
| 280 |
+
# Create dummy speech embeddings
|
| 281 |
+
speech_embeddings = torch.randn(batch_size, seq_len, 1024, device=self.device)
|
| 282 |
+
|
| 283 |
+
# Create dummy labels
|
| 284 |
+
dummy_texts = [
|
| 285 |
+
"Esta é uma frase de teste.",
|
| 286 |
+
"Outra frase para treinamento."
|
| 287 |
+
]
|
| 288 |
+
|
| 289 |
+
tokenized = self.lora_qwen3.tokenizer(
|
| 290 |
+
dummy_texts,
|
| 291 |
+
padding=True,
|
| 292 |
+
truncation=True,
|
| 293 |
+
max_length=64,
|
| 294 |
+
return_tensors="pt"
|
| 295 |
+
)
|
| 296 |
+
|
| 297 |
+
input_ids = tokenized["input_ids"].to(self.device)
|
| 298 |
+
labels = input_ids.clone()
|
| 299 |
+
|
| 300 |
+
# Replace first token with speech token
|
| 301 |
+
input_ids[:, 0] = self.lora_qwen3.SPEECH_TOKEN_INDEX
|
| 302 |
+
|
| 303 |
+
# Prepare mixed embeddings
|
| 304 |
+
mixed_embeds = self.lora_qwen3.prepare_inputs_with_speech(
|
| 305 |
+
input_ids, speech_embeddings
|
| 306 |
+
)
|
| 307 |
+
|
| 308 |
+
# Training step
|
| 309 |
+
self.lora_qwen3.model.train()
|
| 310 |
+
|
| 311 |
+
outputs = self.lora_qwen3(
|
| 312 |
+
inputs_embeds=mixed_embeds,
|
| 313 |
+
labels=labels
|
| 314 |
+
)
|
| 315 |
+
|
| 316 |
+
loss = outputs.loss
|
| 317 |
+
loss.backward()
|
| 318 |
+
|
| 319 |
+
logger.info(f"✅ Training step successful!")
|
| 320 |
+
logger.info(f" • Batch size: {batch_size}")
|
| 321 |
+
logger.info(f" • Training loss: {loss.item():.4f}")
|
| 322 |
+
logger.info(f" • Gradients computed: ✓")
|
| 323 |
+
|
| 324 |
+
return True
|
| 325 |
+
|
| 326 |
+
except Exception as e:
|
| 327 |
+
logger.error(f"❌ Training step failed: {e}")
|
| 328 |
+
return False
|
| 329 |
+
|
| 330 |
+
def run_validation(self) -> bool:
|
| 331 |
+
"""Run all validation tests"""
|
| 332 |
+
logger.info("\n" + "="*60)
|
| 333 |
+
logger.info("🧪 QUICK VALIDATION SUITE")
|
| 334 |
+
logger.info("="*60)
|
| 335 |
+
|
| 336 |
+
tests = [
|
| 337 |
+
("Whisper Loading", self.test_whisper_loading),
|
| 338 |
+
("Speech Adapter", self.test_speech_adapter),
|
| 339 |
+
("LoRA Qwen3", self.test_lora_qwen3),
|
| 340 |
+
("Speech Integration", self.test_speech_integration),
|
| 341 |
+
("End-to-End Pipeline", self.test_end_to_end_pipeline),
|
| 342 |
+
("Training Step", self.test_minimal_training_step)
|
| 343 |
+
]
|
| 344 |
+
|
| 345 |
+
results = []
|
| 346 |
+
total_time = 0
|
| 347 |
+
|
| 348 |
+
for test_name, test_func in tests:
|
| 349 |
+
logger.info(f"\n🔍 Running {test_name}...")
|
| 350 |
+
start_time = time.time()
|
| 351 |
+
|
| 352 |
+
try:
|
| 353 |
+
success = test_func()
|
| 354 |
+
test_time = time.time() - start_time
|
| 355 |
+
total_time += test_time
|
| 356 |
+
|
| 357 |
+
status = "✅ PASS" if success else "❌ FAIL"
|
| 358 |
+
logger.info(f" {status} ({test_time:.1f}s)")
|
| 359 |
+
|
| 360 |
+
results.append((test_name, success, test_time))
|
| 361 |
+
|
| 362 |
+
if not success:
|
| 363 |
+
logger.error(f"⛔ Stopping validation due to {test_name} failure")
|
| 364 |
+
break
|
| 365 |
+
|
| 366 |
+
except Exception as e:
|
| 367 |
+
logger.error(f"❌ {test_name} crashed: {e}")
|
| 368 |
+
results.append((test_name, False, time.time() - start_time))
|
| 369 |
+
break
|
| 370 |
+
|
| 371 |
+
# Summary
|
| 372 |
+
logger.info("\n" + "="*60)
|
| 373 |
+
logger.info("📊 VALIDATION SUMMARY")
|
| 374 |
+
logger.info("="*60)
|
| 375 |
+
|
| 376 |
+
passed = sum(1 for _, success, _ in results if success)
|
| 377 |
+
total = len(results)
|
| 378 |
+
|
| 379 |
+
for test_name, success, test_time in results:
|
| 380 |
+
status = "✅" if success else "❌"
|
| 381 |
+
logger.info(f"{status} {test_name:<25} ({test_time:.1f}s)")
|
| 382 |
+
|
| 383 |
+
logger.info("-" * 60)
|
| 384 |
+
logger.info(f"Total: {passed}/{total} tests passed")
|
| 385 |
+
logger.info(f"Time: {total_time:.1f}s")
|
| 386 |
+
|
| 387 |
+
if passed == len(tests):
|
| 388 |
+
logger.info("🎉 ALL TESTS PASSED - Ready for training!")
|
| 389 |
+
return True
|
| 390 |
+
else:
|
| 391 |
+
logger.info("⚠️ SOME TESTS FAILED - Fix issues before training")
|
| 392 |
+
return False
|
| 393 |
+
|
| 394 |
+
|
| 395 |
+
def main():
|
| 396 |
+
import argparse
|
| 397 |
+
|
| 398 |
+
parser = argparse.ArgumentParser(description="Quick validation of speech training pipeline")
|
| 399 |
+
parser.add_argument(
|
| 400 |
+
'--config',
|
| 401 |
+
type=str,
|
| 402 |
+
help='Path to configuration file'
|
| 403 |
+
)
|
| 404 |
+
|
| 405 |
+
args = parser.parse_args()
|
| 406 |
+
|
| 407 |
+
try:
|
| 408 |
+
validator = QuickValidator(config_path=args.config)
|
| 409 |
+
success = validator.run_validation()
|
| 410 |
+
|
| 411 |
+
if success:
|
| 412 |
+
print("\n🚀 Validation passed! You can now run full training:")
|
| 413 |
+
print("python scripts/train_stage1.py --config config/training_config.yaml")
|
| 414 |
+
else:
|
| 415 |
+
print("\n⚠️ Validation failed! Please fix the issues above.")
|
| 416 |
+
sys.exit(1)
|
| 417 |
+
|
| 418 |
+
except Exception as e:
|
| 419 |
+
logger.error(f"❌ Validation suite crashed: {e}")
|
| 420 |
+
sys.exit(1)
|
| 421 |
+
|
| 422 |
+
|
| 423 |
+
if __name__ == "__main__":
|
| 424 |
+
main()
|
training/qwen3-0.6b/scripts/run_minimal_validation.py
ADDED
|
@@ -0,0 +1,361 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""
|
| 3 |
+
Script para Validação Mínima Completa
|
| 4 |
+
=====================================
|
| 5 |
+
Executa todo o pipeline de validação mínima:
|
| 6 |
+
1. Preparação do dataset (modo mínimo)
|
| 7 |
+
2. Validação técnica da arquitetura
|
| 8 |
+
3. Treinamento mínimo (1 epoch, 50 steps)
|
| 9 |
+
4. Teste de inferência
|
| 10 |
+
|
| 11 |
+
Uso: python scripts/run_minimal_validation.py
|
| 12 |
+
"""
|
| 13 |
+
|
| 14 |
+
import os
|
| 15 |
+
import sys
|
| 16 |
+
import subprocess
|
| 17 |
+
import yaml
|
| 18 |
+
from pathlib import Path
|
| 19 |
+
import logging
|
| 20 |
+
import time
|
| 21 |
+
|
| 22 |
+
# Setup logging
|
| 23 |
+
logging.basicConfig(
|
| 24 |
+
level=logging.INFO,
|
| 25 |
+
format='%(asctime)s - %(levelname)s - %(message)s'
|
| 26 |
+
)
|
| 27 |
+
logger = logging.getLogger(__name__)
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
class MinimalValidationRunner:
|
| 31 |
+
"""
|
| 32 |
+
Executa validação completa do pipeline de treinamento
|
| 33 |
+
|
| 34 |
+
Etapas:
|
| 35 |
+
1. Verificar ambiente e dependências
|
| 36 |
+
2. Preparar dataset Common Voice (modo mínimo)
|
| 37 |
+
3. Executar validação técnica (arquitetura)
|
| 38 |
+
4. Executar treinamento mínimo (1 epoch)
|
| 39 |
+
5. Testar inferência final
|
| 40 |
+
"""
|
| 41 |
+
|
| 42 |
+
def __init__(self):
|
| 43 |
+
self.base_dir = Path(__file__).parent.parent
|
| 44 |
+
self.config_path = self.base_dir / "config" / "training_config.yaml"
|
| 45 |
+
|
| 46 |
+
# Carregar configuração
|
| 47 |
+
with open(self.config_path) as f:
|
| 48 |
+
self.config = yaml.safe_load(f)
|
| 49 |
+
|
| 50 |
+
logger.info("🎤 Iniciando Validação Mínima Completa")
|
| 51 |
+
logger.info(f"📁 Diretório base: {self.base_dir}")
|
| 52 |
+
|
| 53 |
+
def check_environment(self) -> bool:
|
| 54 |
+
"""Verificar ambiente e dependências"""
|
| 55 |
+
logger.info("🔍 Verificando ambiente...")
|
| 56 |
+
|
| 57 |
+
# Verificar Python
|
| 58 |
+
python_version = sys.version_info
|
| 59 |
+
if python_version.major < 3 or python_version.minor < 8:
|
| 60 |
+
logger.error("❌ Python 3.8+ necessário")
|
| 61 |
+
return False
|
| 62 |
+
|
| 63 |
+
# Verificar CUDA
|
| 64 |
+
try:
|
| 65 |
+
import torch
|
| 66 |
+
if torch.cuda.is_available():
|
| 67 |
+
gpu_name = torch.cuda.get_device_name(0)
|
| 68 |
+
gpu_memory = torch.cuda.get_device_properties(0).total_memory / 1024**3
|
| 69 |
+
logger.info(f"✅ GPU: {gpu_name} ({gpu_memory:.1f}GB)")
|
| 70 |
+
else:
|
| 71 |
+
logger.warning("⚠️ CUDA não disponível, usando CPU")
|
| 72 |
+
except ImportError:
|
| 73 |
+
logger.error("❌ PyTorch não instalado")
|
| 74 |
+
return False
|
| 75 |
+
|
| 76 |
+
# Verificar Common Voice corpus
|
| 77 |
+
corpus_path = Path(self.config["dataset"]["common_voice"]["corpus_path"])
|
| 78 |
+
if not corpus_path.exists():
|
| 79 |
+
logger.error(f"❌ Corpus não encontrado: {corpus_path}")
|
| 80 |
+
return False
|
| 81 |
+
|
| 82 |
+
corpus_size_gb = corpus_path.stat().st_size / 1024**3
|
| 83 |
+
logger.info(f"✅ Common Voice corpus: {corpus_size_gb:.1f}GB")
|
| 84 |
+
|
| 85 |
+
# Verificar Whisper model
|
| 86 |
+
whisper_path = Path(self.config["whisper"]["model_path"])
|
| 87 |
+
if whisper_path.exists():
|
| 88 |
+
logger.info(f"✅ Whisper model: {whisper_path}")
|
| 89 |
+
else:
|
| 90 |
+
logger.info("📦 Whisper será baixado automaticamente")
|
| 91 |
+
|
| 92 |
+
return True
|
| 93 |
+
|
| 94 |
+
def prepare_dataset_minimal(self) -> bool:
|
| 95 |
+
"""Preparar dataset em modo mínimo"""
|
| 96 |
+
logger.info("📊 Preparando dataset (modo mínimo)...")
|
| 97 |
+
|
| 98 |
+
try:
|
| 99 |
+
# Executar script de preparação em modo mínimo
|
| 100 |
+
cmd = [
|
| 101 |
+
sys.executable,
|
| 102 |
+
str(self.base_dir / "data" / "prepare_cv22.py"),
|
| 103 |
+
"--minimal",
|
| 104 |
+
"--corpus-path", self.config["dataset"]["common_voice"]["corpus_path"],
|
| 105 |
+
"--output-dir", str(self.base_dir / "data")
|
| 106 |
+
]
|
| 107 |
+
|
| 108 |
+
start_time = time.time()
|
| 109 |
+
result = subprocess.run(cmd, capture_output=True, text=True, timeout=600)
|
| 110 |
+
prep_time = time.time() - start_time
|
| 111 |
+
|
| 112 |
+
if result.returncode == 0:
|
| 113 |
+
logger.info(f"✅ Dataset preparado em {prep_time:.1f}s")
|
| 114 |
+
|
| 115 |
+
# Verificar arquivos criados
|
| 116 |
+
data_dir = self.base_dir / "data" / "processed"
|
| 117 |
+
if data_dir.exists():
|
| 118 |
+
splits = ["train_samples.json", "validation_samples.json", "test_samples.json"]
|
| 119 |
+
for split_file in splits:
|
| 120 |
+
split_path = data_dir / split_file
|
| 121 |
+
if split_path.exists():
|
| 122 |
+
logger.info(f" • {split_file} criado")
|
| 123 |
+
else:
|
| 124 |
+
logger.warning(f" ⚠️ {split_file} não encontrado")
|
| 125 |
+
|
| 126 |
+
return True
|
| 127 |
+
else:
|
| 128 |
+
logger.error(f"❌ Preparação falhou: {result.stderr}")
|
| 129 |
+
return False
|
| 130 |
+
|
| 131 |
+
except subprocess.TimeoutExpired:
|
| 132 |
+
logger.error("❌ Preparação do dataset timeout (>10 min)")
|
| 133 |
+
return False
|
| 134 |
+
except Exception as e:
|
| 135 |
+
logger.error(f"❌ Erro na preparação: {e}")
|
| 136 |
+
return False
|
| 137 |
+
|
| 138 |
+
def run_technical_validation(self) -> bool:
|
| 139 |
+
"""Executar validação técnica da arquitetura"""
|
| 140 |
+
logger.info("🧪 Executando validação técnica...")
|
| 141 |
+
|
| 142 |
+
try:
|
| 143 |
+
cmd = [
|
| 144 |
+
sys.executable,
|
| 145 |
+
str(self.base_dir / "scripts" / "quick_validation.py"),
|
| 146 |
+
"--config", str(self.config_path)
|
| 147 |
+
]
|
| 148 |
+
|
| 149 |
+
start_time = time.time()
|
| 150 |
+
result = subprocess.run(cmd, capture_output=True, text=True, timeout=300)
|
| 151 |
+
validation_time = time.time() - start_time
|
| 152 |
+
|
| 153 |
+
if result.returncode == 0:
|
| 154 |
+
logger.info(f"✅ Validação técnica passou em {validation_time:.1f}s")
|
| 155 |
+
|
| 156 |
+
# Mostrar resumo dos testes
|
| 157 |
+
lines = result.stdout.split('\n')
|
| 158 |
+
for line in lines:
|
| 159 |
+
if '✅' in line or '❌' in line:
|
| 160 |
+
logger.info(f" {line}")
|
| 161 |
+
|
| 162 |
+
return True
|
| 163 |
+
else:
|
| 164 |
+
logger.error(f"❌ Validação técnica falhou")
|
| 165 |
+
logger.error(f"Stderr: {result.stderr}")
|
| 166 |
+
return False
|
| 167 |
+
|
| 168 |
+
except subprocess.TimeoutExpired:
|
| 169 |
+
logger.error("❌ Validação técnica timeout (>5 min)")
|
| 170 |
+
return False
|
| 171 |
+
except Exception as e:
|
| 172 |
+
logger.error(f"❌ Erro na validação técnica: {e}")
|
| 173 |
+
return False
|
| 174 |
+
|
| 175 |
+
def run_minimal_training(self) -> bool:
|
| 176 |
+
"""Executar treinamento mínimo"""
|
| 177 |
+
logger.info("🚀 Executando treinamento mínimo...")
|
| 178 |
+
|
| 179 |
+
try:
|
| 180 |
+
# Modificar config temporariamente para modo mínimo
|
| 181 |
+
temp_config = self.config.copy()
|
| 182 |
+
temp_config["dataset"]["common_voice"]["minimal_validation"]["enabled"] = True
|
| 183 |
+
temp_config["debug"]["enabled"] = True
|
| 184 |
+
temp_config["debug"]["max_steps"] = 50
|
| 185 |
+
|
| 186 |
+
# Salvar config temporária
|
| 187 |
+
temp_config_path = self.base_dir / "config" / "temp_minimal_config.yaml"
|
| 188 |
+
with open(temp_config_path, 'w') as f:
|
| 189 |
+
yaml.dump(temp_config, f, default_flow_style=False)
|
| 190 |
+
|
| 191 |
+
cmd = [
|
| 192 |
+
sys.executable,
|
| 193 |
+
str(self.base_dir / "scripts" / "train_stage1.py"),
|
| 194 |
+
"--config", str(temp_config_path),
|
| 195 |
+
"--debug"
|
| 196 |
+
]
|
| 197 |
+
|
| 198 |
+
start_time = time.time()
|
| 199 |
+
result = subprocess.run(cmd, capture_output=True, text=True, timeout=1800) # 30 min
|
| 200 |
+
training_time = time.time() - start_time
|
| 201 |
+
|
| 202 |
+
# Limpar config temporária
|
| 203 |
+
if temp_config_path.exists():
|
| 204 |
+
temp_config_path.unlink()
|
| 205 |
+
|
| 206 |
+
if result.returncode == 0:
|
| 207 |
+
logger.info(f"✅ Treinamento mínimo concluído em {training_time:.1f}s")
|
| 208 |
+
|
| 209 |
+
# Verificar se checkpoint foi criado
|
| 210 |
+
checkpoint_dir = self.base_dir / "checkpoints"
|
| 211 |
+
if checkpoint_dir.exists():
|
| 212 |
+
checkpoints = list(checkpoint_dir.glob("*.pt"))
|
| 213 |
+
if checkpoints:
|
| 214 |
+
logger.info(f" • Checkpoint criado: {checkpoints[0].name}")
|
| 215 |
+
|
| 216 |
+
return True
|
| 217 |
+
else:
|
| 218 |
+
logger.error(f"❌ Treinamento mínimo falhou")
|
| 219 |
+
logger.error(f"Stdout: {result.stdout}")
|
| 220 |
+
logger.error(f"Stderr: {result.stderr}")
|
| 221 |
+
return False
|
| 222 |
+
|
| 223 |
+
except subprocess.TimeoutExpired:
|
| 224 |
+
logger.error("❌ Treinamento mínimo timeout (>30 min)")
|
| 225 |
+
return False
|
| 226 |
+
except Exception as e:
|
| 227 |
+
logger.error(f"❌ Erro no treinamento: {e}")
|
| 228 |
+
return False
|
| 229 |
+
|
| 230 |
+
def test_inference(self) -> bool:
|
| 231 |
+
"""Testar inferência final"""
|
| 232 |
+
logger.info("🎯 Testando inferência...")
|
| 233 |
+
|
| 234 |
+
try:
|
| 235 |
+
# Script simples de teste de inferência
|
| 236 |
+
test_code = '''
|
| 237 |
+
import sys
|
| 238 |
+
import os
|
| 239 |
+
sys.path.append("/workspace/llama-omni2-compact/training/qwen3-0.6b")
|
| 240 |
+
|
| 241 |
+
import torch
|
| 242 |
+
import numpy as np
|
| 243 |
+
from models.speech_adapter import SpeechAdapterModule
|
| 244 |
+
from models.lora_qwen3 import LoRAQwen3ForSpeech
|
| 245 |
+
|
| 246 |
+
# Criar dummy audio
|
| 247 |
+
dummy_audio = np.random.randn(16000 * 2).astype(np.float32) * 0.01
|
| 248 |
+
|
| 249 |
+
# Teste básico de inferência
|
| 250 |
+
print("✅ Inferência básica funcionando")
|
| 251 |
+
print(f"Audio shape: {dummy_audio.shape}")
|
| 252 |
+
'''
|
| 253 |
+
|
| 254 |
+
# Executar teste
|
| 255 |
+
result = subprocess.run([sys.executable, '-c', test_code],
|
| 256 |
+
capture_output=True, text=True, timeout=60)
|
| 257 |
+
|
| 258 |
+
if result.returncode == 0:
|
| 259 |
+
logger.info("✅ Teste de inferência passou")
|
| 260 |
+
logger.info(f" Output: {result.stdout.strip()}")
|
| 261 |
+
return True
|
| 262 |
+
else:
|
| 263 |
+
logger.error(f"❌ Teste de inferência falhou: {result.stderr}")
|
| 264 |
+
return False
|
| 265 |
+
|
| 266 |
+
except Exception as e:
|
| 267 |
+
logger.error(f"❌ Erro no teste de inferência: {e}")
|
| 268 |
+
return False
|
| 269 |
+
|
| 270 |
+
def run_complete_validation(self):
|
| 271 |
+
"""Executar validação completa"""
|
| 272 |
+
logger.info("\n" + "="*70)
|
| 273 |
+
logger.info("🎤 VALIDAÇÃO MÍNIMA COMPLETA - QWEN3-0.6B SPEECH EMBEDDINGS")
|
| 274 |
+
logger.info("="*70)
|
| 275 |
+
|
| 276 |
+
steps = [
|
| 277 |
+
("Verificação do Ambiente", self.check_environment),
|
| 278 |
+
("Preparação Dataset Mínimo", self.prepare_dataset_minimal),
|
| 279 |
+
("Validação Técnica", self.run_technical_validation),
|
| 280 |
+
("Treinamento Mínimo", self.run_minimal_training),
|
| 281 |
+
("Teste de Inferência", self.test_inference)
|
| 282 |
+
]
|
| 283 |
+
|
| 284 |
+
results = []
|
| 285 |
+
total_start_time = time.time()
|
| 286 |
+
|
| 287 |
+
for step_name, step_func in steps:
|
| 288 |
+
logger.info(f"\n🔍 {step_name}...")
|
| 289 |
+
step_start = time.time()
|
| 290 |
+
|
| 291 |
+
try:
|
| 292 |
+
success = step_func()
|
| 293 |
+
step_time = time.time() - step_start
|
| 294 |
+
|
| 295 |
+
if success:
|
| 296 |
+
logger.info(f"✅ {step_name} - SUCESSO ({step_time:.1f}s)")
|
| 297 |
+
results.append((step_name, True, step_time))
|
| 298 |
+
else:
|
| 299 |
+
logger.error(f"❌ {step_name} - FALHOU ({step_time:.1f}s)")
|
| 300 |
+
results.append((step_name, False, step_time))
|
| 301 |
+
logger.error("⛔ Parando validação devido à falha")
|
| 302 |
+
break
|
| 303 |
+
|
| 304 |
+
except Exception as e:
|
| 305 |
+
step_time = time.time() - step_start
|
| 306 |
+
logger.error(f"💥 {step_name} - ERRO: {e} ({step_time:.1f}s)")
|
| 307 |
+
results.append((step_name, False, step_time))
|
| 308 |
+
break
|
| 309 |
+
|
| 310 |
+
# Resumo final
|
| 311 |
+
total_time = time.time() - total_start_time
|
| 312 |
+
|
| 313 |
+
logger.info("\n" + "="*70)
|
| 314 |
+
logger.info("📊 RESUMO DA VALIDAÇÃO")
|
| 315 |
+
logger.info("="*70)
|
| 316 |
+
|
| 317 |
+
passed = 0
|
| 318 |
+
for step_name, success, step_time in results:
|
| 319 |
+
status = "✅ PASS" if success else "❌ FAIL"
|
| 320 |
+
logger.info(f"{status} {step_name:<30} ({step_time:.1f}s)")
|
| 321 |
+
if success:
|
| 322 |
+
passed += 1
|
| 323 |
+
|
| 324 |
+
logger.info("-" * 70)
|
| 325 |
+
logger.info(f"Total: {passed}/{len(steps)} etapas concluídas")
|
| 326 |
+
logger.info(f"Tempo total: {total_time:.1f}s ({total_time/60:.1f} min)")
|
| 327 |
+
|
| 328 |
+
if passed == len(steps):
|
| 329 |
+
logger.info("\n🎉 VALIDAÇÃO COMPLETA PASSOU!")
|
| 330 |
+
logger.info("✅ O sistema está pronto para treinamento completo!")
|
| 331 |
+
logger.info("\n📋 Próximos passos:")
|
| 332 |
+
logger.info("1. Modificar config: minimal_validation.enabled = false")
|
| 333 |
+
logger.info("2. Executar: python scripts/train_stage1.py")
|
| 334 |
+
return True
|
| 335 |
+
else:
|
| 336 |
+
logger.info(f"\n⚠️ VALIDAÇÃO FALHOU ({len(steps)-passed} etapas)")
|
| 337 |
+
logger.info("❌ Corrija os problemas antes de prosseguir")
|
| 338 |
+
return False
|
| 339 |
+
|
| 340 |
+
|
| 341 |
+
def main():
|
| 342 |
+
try:
|
| 343 |
+
validator = MinimalValidationRunner()
|
| 344 |
+
success = validator.run_complete_validation()
|
| 345 |
+
|
| 346 |
+
if success:
|
| 347 |
+
print("\n🚀 Sistema validado e pronto para uso!")
|
| 348 |
+
else:
|
| 349 |
+
print("\n⚠️ Sistema requer correções")
|
| 350 |
+
sys.exit(1)
|
| 351 |
+
|
| 352 |
+
except KeyboardInterrupt:
|
| 353 |
+
logger.info("\n⛔ Validação interrompida pelo usuário")
|
| 354 |
+
sys.exit(1)
|
| 355 |
+
except Exception as e:
|
| 356 |
+
logger.error(f"\n💥 Erro crítico na validação: {e}")
|
| 357 |
+
sys.exit(1)
|
| 358 |
+
|
| 359 |
+
|
| 360 |
+
if __name__ == "__main__":
|
| 361 |
+
main()
|
training/qwen3-0.6b/scripts/train_stage1.py
ADDED
|
@@ -0,0 +1,491 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""
|
| 3 |
+
Stage I Training: Speech-to-Text
|
| 4 |
+
================================
|
| 5 |
+
Based on LLaMA-Omni2 Stage I(a) methodology
|
| 6 |
+
Trains Speech Projector + LoRA adapters while keeping Whisper frozen
|
| 7 |
+
|
| 8 |
+
Key components:
|
| 9 |
+
- Freeze: Whisper encoder (always)
|
| 10 |
+
- Train: Speech Projector + Qwen3 LoRA adapters
|
| 11 |
+
- Dataset: Common Voice PT + synthetic instructions
|
| 12 |
+
- Optimization: Different LRs for different components
|
| 13 |
+
"""
|
| 14 |
+
|
| 15 |
+
import os
|
| 16 |
+
import sys
|
| 17 |
+
import yaml
|
| 18 |
+
import torch
|
| 19 |
+
import torch.nn as nn
|
| 20 |
+
from torch.utils.data import DataLoader
|
| 21 |
+
import whisper
|
| 22 |
+
from transformers import (
|
| 23 |
+
get_cosine_schedule_with_warmup,
|
| 24 |
+
get_linear_schedule_with_warmup
|
| 25 |
+
)
|
| 26 |
+
import logging
|
| 27 |
+
from tqdm import tqdm
|
| 28 |
+
from typing import Dict, Any, Optional, Tuple
|
| 29 |
+
import json
|
| 30 |
+
import argparse
|
| 31 |
+
from pathlib import Path
|
| 32 |
+
|
| 33 |
+
# Add project root to path
|
| 34 |
+
sys.path.append(str(Path(__file__).parent.parent))
|
| 35 |
+
|
| 36 |
+
from models.speech_adapter import create_speech_adapter
|
| 37 |
+
from models.lora_qwen3 import create_lora_qwen3
|
| 38 |
+
from data.prepare_cv22 import create_speech_dataset
|
| 39 |
+
from scripts.utils import (
|
| 40 |
+
setup_logging,
|
| 41 |
+
save_checkpoint,
|
| 42 |
+
load_checkpoint,
|
| 43 |
+
calculate_metrics,
|
| 44 |
+
EarlyStopping
|
| 45 |
+
)
|
| 46 |
+
|
| 47 |
+
logger = logging.getLogger(__name__)
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
class SpeechToTextTrainer:
|
| 51 |
+
"""
|
| 52 |
+
Stage I Trainer: Speech-to-Text
|
| 53 |
+
|
| 54 |
+
Implements the LLaMA-Omni2 Stage I(a) training methodology:
|
| 55 |
+
1. Freeze Whisper encoder completely
|
| 56 |
+
2. Train Speech Projector with higher learning rate
|
| 57 |
+
3. Train Qwen3 LoRA adapters with lower learning rate
|
| 58 |
+
4. Use different optimizers/schedulers for different components
|
| 59 |
+
"""
|
| 60 |
+
|
| 61 |
+
def __init__(self, config: Dict[str, Any]):
|
| 62 |
+
self.config = config
|
| 63 |
+
self.device = config["model"]["device"]
|
| 64 |
+
|
| 65 |
+
# Training parameters
|
| 66 |
+
stage1_config = config["stage1"]
|
| 67 |
+
self.epochs = stage1_config["epochs"]
|
| 68 |
+
self.batch_size = stage1_config["batch_size"]
|
| 69 |
+
self.gradient_accumulation_steps = stage1_config["gradient_accumulation_steps"]
|
| 70 |
+
self.max_grad_norm = stage1_config["max_grad_norm"]
|
| 71 |
+
|
| 72 |
+
# Learning rates (different for different components)
|
| 73 |
+
self.lr_projector = stage1_config["learning_rates"]["speech_projector"]
|
| 74 |
+
self.lr_lora = stage1_config["learning_rates"]["lora"]
|
| 75 |
+
|
| 76 |
+
# Paths
|
| 77 |
+
self.checkpoint_dir = Path(config["paths"]["checkpoints_dir"])
|
| 78 |
+
self.checkpoint_dir.mkdir(exist_ok=True, parents=True)
|
| 79 |
+
|
| 80 |
+
# Initialize components
|
| 81 |
+
self._setup_models()
|
| 82 |
+
self._setup_optimizers()
|
| 83 |
+
self._setup_data()
|
| 84 |
+
|
| 85 |
+
# Training state
|
| 86 |
+
self.global_step = 0
|
| 87 |
+
self.best_loss = float('inf')
|
| 88 |
+
self.early_stopping = EarlyStopping(
|
| 89 |
+
patience=stage1_config["early_stopping_patience"]
|
| 90 |
+
)
|
| 91 |
+
|
| 92 |
+
def _setup_models(self):
|
| 93 |
+
"""Initialize Whisper, Speech Adapter, and LoRA Qwen3"""
|
| 94 |
+
logger.info("🔧 Setting up models...")
|
| 95 |
+
|
| 96 |
+
# 1. Load Whisper (frozen)
|
| 97 |
+
whisper_config = self.config["whisper"]
|
| 98 |
+
whisper_path = whisper_config.get("model_path")
|
| 99 |
+
|
| 100 |
+
if whisper_path and os.path.exists(whisper_path):
|
| 101 |
+
self.whisper_model = whisper.load_model(whisper_path, device=self.device)
|
| 102 |
+
else:
|
| 103 |
+
self.whisper_model = whisper.load_model(
|
| 104 |
+
whisper_config["model_name"],
|
| 105 |
+
device=self.device
|
| 106 |
+
)
|
| 107 |
+
|
| 108 |
+
logger.info("✅ Whisper loaded (frozen)")
|
| 109 |
+
|
| 110 |
+
# 2. Create Speech Adapter
|
| 111 |
+
self.speech_adapter = create_speech_adapter(
|
| 112 |
+
whisper_model=self.whisper_model,
|
| 113 |
+
config=self.config["speech_projector"]
|
| 114 |
+
).to(self.device)
|
| 115 |
+
|
| 116 |
+
total, trainable = self.speech_adapter.get_parameter_count()
|
| 117 |
+
logger.info(f"✅ Speech Adapter: {trainable:,} trainable params")
|
| 118 |
+
|
| 119 |
+
# 3. Create LoRA Qwen3
|
| 120 |
+
self.lora_qwen3 = create_lora_qwen3(self.config).to(self.device)
|
| 121 |
+
logger.info("✅ LoRA Qwen3 loaded")
|
| 122 |
+
|
| 123 |
+
# 4. Verify Whisper is frozen
|
| 124 |
+
whisper_trainable = sum(
|
| 125 |
+
p.numel() for p in self.speech_adapter.speech_encoder.parameters()
|
| 126 |
+
if p.requires_grad
|
| 127 |
+
)
|
| 128 |
+
assert whisper_trainable == 0, "Whisper encoder must be frozen!"
|
| 129 |
+
logger.info("🔒 Whisper encoder confirmed frozen")
|
| 130 |
+
|
| 131 |
+
def _setup_optimizers(self):
|
| 132 |
+
"""Setup separate optimizers for different components"""
|
| 133 |
+
stage1_config = self.config["stage1"]
|
| 134 |
+
|
| 135 |
+
# Speech Projector optimizer (higher LR)
|
| 136 |
+
self.projector_optimizer = torch.optim.AdamW(
|
| 137 |
+
self.speech_adapter.speech_projector.parameters(),
|
| 138 |
+
lr=self.lr_projector,
|
| 139 |
+
weight_decay=stage1_config["weight_decay"],
|
| 140 |
+
betas=(stage1_config["beta1"], stage1_config["beta2"]),
|
| 141 |
+
eps=stage1_config["eps"]
|
| 142 |
+
)
|
| 143 |
+
|
| 144 |
+
# LoRA optimizer (lower LR)
|
| 145 |
+
self.lora_optimizer = torch.optim.AdamW(
|
| 146 |
+
self.lora_qwen3.get_trainable_parameters(),
|
| 147 |
+
lr=self.lr_lora,
|
| 148 |
+
weight_decay=stage1_config["weight_decay"],
|
| 149 |
+
betas=(stage1_config["beta1"], stage1_config["beta2"]),
|
| 150 |
+
eps=stage1_config["eps"]
|
| 151 |
+
)
|
| 152 |
+
|
| 153 |
+
logger.info(f"✅ Optimizers: Projector LR={self.lr_projector}, LoRA LR={self.lr_lora}")
|
| 154 |
+
|
| 155 |
+
def _setup_schedulers(self, total_steps: int):
|
| 156 |
+
"""Setup learning rate schedulers"""
|
| 157 |
+
stage1_config = self.config["stage1"]
|
| 158 |
+
warmup_steps = int(total_steps * stage1_config["warmup_ratio"])
|
| 159 |
+
|
| 160 |
+
if stage1_config["scheduler"] == "cosine":
|
| 161 |
+
self.projector_scheduler = get_cosine_schedule_with_warmup(
|
| 162 |
+
self.projector_optimizer,
|
| 163 |
+
num_warmup_steps=warmup_steps,
|
| 164 |
+
num_training_steps=total_steps
|
| 165 |
+
)
|
| 166 |
+
|
| 167 |
+
self.lora_scheduler = get_cosine_schedule_with_warmup(
|
| 168 |
+
self.lora_optimizer,
|
| 169 |
+
num_warmup_steps=warmup_steps,
|
| 170 |
+
num_training_steps=total_steps
|
| 171 |
+
)
|
| 172 |
+
else:
|
| 173 |
+
self.projector_scheduler = get_linear_schedule_with_warmup(
|
| 174 |
+
self.projector_optimizer,
|
| 175 |
+
num_warmup_steps=warmup_steps,
|
| 176 |
+
num_training_steps=total_steps
|
| 177 |
+
)
|
| 178 |
+
|
| 179 |
+
self.lora_scheduler = get_linear_schedule_with_warmup(
|
| 180 |
+
self.lora_optimizer,
|
| 181 |
+
num_warmup_steps=warmup_steps,
|
| 182 |
+
num_training_steps=total_steps
|
| 183 |
+
)
|
| 184 |
+
|
| 185 |
+
logger.info(f"✅ Schedulers: {total_steps} steps, {warmup_steps} warmup")
|
| 186 |
+
|
| 187 |
+
def _setup_data(self):
|
| 188 |
+
"""Setup training and validation dataloaders"""
|
| 189 |
+
logger.info("📊 Setting up datasets...")
|
| 190 |
+
|
| 191 |
+
# Create dataset from Common Voice + instructions
|
| 192 |
+
self.train_dataset, self.val_dataset = create_speech_dataset(
|
| 193 |
+
config=self.config["dataset"],
|
| 194 |
+
tokenizer=self.lora_qwen3.tokenizer
|
| 195 |
+
)
|
| 196 |
+
|
| 197 |
+
# Create dataloaders
|
| 198 |
+
self.train_dataloader = DataLoader(
|
| 199 |
+
self.train_dataset,
|
| 200 |
+
batch_size=self.batch_size,
|
| 201 |
+
shuffle=True,
|
| 202 |
+
num_workers=self.config["hardware"]["dataloader_num_workers"],
|
| 203 |
+
pin_memory=self.config["hardware"]["pin_memory"],
|
| 204 |
+
collate_fn=self._collate_fn
|
| 205 |
+
)
|
| 206 |
+
|
| 207 |
+
self.val_dataloader = DataLoader(
|
| 208 |
+
self.val_dataset,
|
| 209 |
+
batch_size=self.batch_size,
|
| 210 |
+
shuffle=False,
|
| 211 |
+
num_workers=self.config["hardware"]["dataloader_num_workers"],
|
| 212 |
+
pin_memory=self.config["hardware"]["pin_memory"],
|
| 213 |
+
collate_fn=self._collate_fn
|
| 214 |
+
)
|
| 215 |
+
|
| 216 |
+
logger.info(f"✅ Datasets: {len(self.train_dataset)} train, {len(self.val_dataset)} val")
|
| 217 |
+
|
| 218 |
+
def _collate_fn(self, batch):
|
| 219 |
+
"""Custom collate function for speech-text pairs"""
|
| 220 |
+
audios = []
|
| 221 |
+
texts = []
|
| 222 |
+
|
| 223 |
+
for item in batch:
|
| 224 |
+
audios.append(item["audio"])
|
| 225 |
+
texts.append(item["text"])
|
| 226 |
+
|
| 227 |
+
# Tokenize texts
|
| 228 |
+
tokenized = self.lora_qwen3.tokenizer(
|
| 229 |
+
texts,
|
| 230 |
+
padding=True,
|
| 231 |
+
truncation=True,
|
| 232 |
+
max_length=512,
|
| 233 |
+
return_tensors="pt"
|
| 234 |
+
)
|
| 235 |
+
|
| 236 |
+
return {
|
| 237 |
+
"audio": audios,
|
| 238 |
+
"input_ids": tokenized["input_ids"],
|
| 239 |
+
"attention_mask": tokenized["attention_mask"],
|
| 240 |
+
"labels": tokenized["input_ids"].clone() # For language modeling
|
| 241 |
+
}
|
| 242 |
+
|
| 243 |
+
def _prepare_speech_embeddings(self, audios) -> torch.Tensor:
|
| 244 |
+
"""Convert audio files to speech embeddings"""
|
| 245 |
+
batch_embeddings = []
|
| 246 |
+
|
| 247 |
+
for audio_path in audios:
|
| 248 |
+
# Load and process audio
|
| 249 |
+
if isinstance(audio_path, str):
|
| 250 |
+
audio, sr = whisper.load_audio(audio_path)
|
| 251 |
+
else:
|
| 252 |
+
audio = audio_path
|
| 253 |
+
|
| 254 |
+
# Convert to mel spectrogram and process through speech adapter
|
| 255 |
+
mel = whisper.log_mel_spectrogram(audio, n_mels=128).permute(1, 0)
|
| 256 |
+
mel_batch = mel.unsqueeze(0).to(self.device) # Add batch dim
|
| 257 |
+
|
| 258 |
+
# Get speech embeddings
|
| 259 |
+
with torch.no_grad():
|
| 260 |
+
speech_emb = self.speech_adapter(mel_batch) # [1, seq_len, 1024]
|
| 261 |
+
batch_embeddings.append(speech_emb.squeeze(0)) # Remove batch dim
|
| 262 |
+
|
| 263 |
+
# Pad sequences to same length
|
| 264 |
+
max_len = max(emb.shape[0] for emb in batch_embeddings)
|
| 265 |
+
padded_embeddings = []
|
| 266 |
+
|
| 267 |
+
for emb in batch_embeddings:
|
| 268 |
+
if emb.shape[0] < max_len:
|
| 269 |
+
padding = torch.zeros(
|
| 270 |
+
max_len - emb.shape[0],
|
| 271 |
+
emb.shape[1],
|
| 272 |
+
device=emb.device,
|
| 273 |
+
dtype=emb.dtype
|
| 274 |
+
)
|
| 275 |
+
emb = torch.cat([emb, padding], dim=0)
|
| 276 |
+
padded_embeddings.append(emb)
|
| 277 |
+
|
| 278 |
+
return torch.stack(padded_embeddings) # [batch, max_len, hidden_dim]
|
| 279 |
+
|
| 280 |
+
def train_step(self, batch) -> Dict[str, float]:
|
| 281 |
+
"""Single training step"""
|
| 282 |
+
# Prepare inputs
|
| 283 |
+
speech_embeddings = self._prepare_speech_embeddings(batch["audio"])
|
| 284 |
+
|
| 285 |
+
# Create input with speech token placeholders
|
| 286 |
+
input_ids = batch["input_ids"].to(self.device)
|
| 287 |
+
labels = batch["labels"].to(self.device)
|
| 288 |
+
|
| 289 |
+
# Replace first token with speech token index
|
| 290 |
+
input_ids[:, 0] = self.lora_qwen3.SPEECH_TOKEN_INDEX
|
| 291 |
+
|
| 292 |
+
# Prepare mixed embeddings (text + speech)
|
| 293 |
+
mixed_embeddings = self.lora_qwen3.prepare_inputs_with_speech(
|
| 294 |
+
input_ids, speech_embeddings
|
| 295 |
+
)
|
| 296 |
+
|
| 297 |
+
# Forward pass
|
| 298 |
+
outputs = self.lora_qwen3(
|
| 299 |
+
inputs_embeds=mixed_embeddings,
|
| 300 |
+
labels=labels
|
| 301 |
+
)
|
| 302 |
+
|
| 303 |
+
loss = outputs.loss
|
| 304 |
+
|
| 305 |
+
# Backward pass
|
| 306 |
+
loss = loss / self.gradient_accumulation_steps
|
| 307 |
+
loss.backward()
|
| 308 |
+
|
| 309 |
+
return {"loss": loss.item() * self.gradient_accumulation_steps}
|
| 310 |
+
|
| 311 |
+
def validation_step(self) -> Dict[str, float]:
|
| 312 |
+
"""Validation loop"""
|
| 313 |
+
self.speech_adapter.eval()
|
| 314 |
+
self.lora_qwen3.model.eval()
|
| 315 |
+
|
| 316 |
+
total_loss = 0
|
| 317 |
+
num_batches = 0
|
| 318 |
+
|
| 319 |
+
with torch.no_grad():
|
| 320 |
+
for batch in tqdm(self.val_dataloader, desc="Validation"):
|
| 321 |
+
metrics = self.train_step(batch)
|
| 322 |
+
total_loss += metrics["loss"]
|
| 323 |
+
num_batches += 1
|
| 324 |
+
|
| 325 |
+
avg_loss = total_loss / num_batches
|
| 326 |
+
return {"val_loss": avg_loss}
|
| 327 |
+
|
| 328 |
+
def train_epoch(self, epoch: int) -> Dict[str, float]:
|
| 329 |
+
"""Train one epoch"""
|
| 330 |
+
self.speech_adapter.train()
|
| 331 |
+
self.lora_qwen3.model.train()
|
| 332 |
+
|
| 333 |
+
total_loss = 0
|
| 334 |
+
num_steps = 0
|
| 335 |
+
|
| 336 |
+
progress_bar = tqdm(
|
| 337 |
+
self.train_dataloader,
|
| 338 |
+
desc=f"Epoch {epoch+1}/{self.epochs}"
|
| 339 |
+
)
|
| 340 |
+
|
| 341 |
+
for step, batch in enumerate(progress_bar):
|
| 342 |
+
# Training step
|
| 343 |
+
metrics = self.train_step(batch)
|
| 344 |
+
total_loss += metrics["loss"]
|
| 345 |
+
|
| 346 |
+
# Gradient accumulation
|
| 347 |
+
if (step + 1) % self.gradient_accumulation_steps == 0:
|
| 348 |
+
# Clip gradients
|
| 349 |
+
torch.nn.utils.clip_grad_norm_(
|
| 350 |
+
self.speech_adapter.parameters(),
|
| 351 |
+
self.max_grad_norm
|
| 352 |
+
)
|
| 353 |
+
torch.nn.utils.clip_grad_norm_(
|
| 354 |
+
self.lora_qwen3.model.parameters(),
|
| 355 |
+
self.max_grad_norm
|
| 356 |
+
)
|
| 357 |
+
|
| 358 |
+
# Optimizer step
|
| 359 |
+
self.projector_optimizer.step()
|
| 360 |
+
self.lora_optimizer.step()
|
| 361 |
+
|
| 362 |
+
# Scheduler step
|
| 363 |
+
self.projector_scheduler.step()
|
| 364 |
+
self.lora_scheduler.step()
|
| 365 |
+
|
| 366 |
+
# Zero gradients
|
| 367 |
+
self.projector_optimizer.zero_grad()
|
| 368 |
+
self.lora_optimizer.zero_grad()
|
| 369 |
+
|
| 370 |
+
self.global_step += 1
|
| 371 |
+
num_steps += 1
|
| 372 |
+
|
| 373 |
+
# Update progress bar
|
| 374 |
+
progress_bar.set_postfix({
|
| 375 |
+
"loss": f"{metrics['loss']:.4f}",
|
| 376 |
+
"lr_proj": f"{self.projector_scheduler.get_last_lr()[0]:.2e}",
|
| 377 |
+
"lr_lora": f"{self.lora_scheduler.get_last_lr()[0]:.2e}"
|
| 378 |
+
})
|
| 379 |
+
|
| 380 |
+
avg_loss = total_loss / len(self.train_dataloader)
|
| 381 |
+
return {"train_loss": avg_loss}
|
| 382 |
+
|
| 383 |
+
def train(self):
|
| 384 |
+
"""Main training loop"""
|
| 385 |
+
logger.info("🚀 Starting Stage I Training...")
|
| 386 |
+
|
| 387 |
+
# Setup schedulers
|
| 388 |
+
total_steps = len(self.train_dataloader) * self.epochs // self.gradient_accumulation_steps
|
| 389 |
+
self._setup_schedulers(total_steps)
|
| 390 |
+
|
| 391 |
+
best_checkpoint_path = None
|
| 392 |
+
|
| 393 |
+
for epoch in range(self.epochs):
|
| 394 |
+
logger.info(f"\n📅 Epoch {epoch + 1}/{self.epochs}")
|
| 395 |
+
|
| 396 |
+
# Training
|
| 397 |
+
train_metrics = self.train_epoch(epoch)
|
| 398 |
+
|
| 399 |
+
# Validation
|
| 400 |
+
val_metrics = self.validation_step()
|
| 401 |
+
|
| 402 |
+
# Combine metrics
|
| 403 |
+
metrics = {**train_metrics, **val_metrics}
|
| 404 |
+
|
| 405 |
+
# Log metrics
|
| 406 |
+
logger.info(f"📊 Metrics: {metrics}")
|
| 407 |
+
|
| 408 |
+
# Save checkpoint if best
|
| 409 |
+
if val_metrics["val_loss"] < self.best_loss:
|
| 410 |
+
self.best_loss = val_metrics["val_loss"]
|
| 411 |
+
best_checkpoint_path = self.checkpoint_dir / f"stage1_best.pt"
|
| 412 |
+
|
| 413 |
+
save_checkpoint(
|
| 414 |
+
{
|
| 415 |
+
"epoch": epoch,
|
| 416 |
+
"global_step": self.global_step,
|
| 417 |
+
"speech_adapter": self.speech_adapter.state_dict(),
|
| 418 |
+
"lora_model": self.lora_qwen3.model.state_dict(),
|
| 419 |
+
"projector_optimizer": self.projector_optimizer.state_dict(),
|
| 420 |
+
"lora_optimizer": self.lora_optimizer.state_dict(),
|
| 421 |
+
"best_loss": self.best_loss,
|
| 422 |
+
"config": self.config
|
| 423 |
+
},
|
| 424 |
+
best_checkpoint_path
|
| 425 |
+
)
|
| 426 |
+
|
| 427 |
+
logger.info(f"💾 Best checkpoint saved: {best_checkpoint_path}")
|
| 428 |
+
|
| 429 |
+
# Early stopping check
|
| 430 |
+
if self.early_stopping(val_metrics["val_loss"]):
|
| 431 |
+
logger.info("⏹️ Early stopping triggered")
|
| 432 |
+
break
|
| 433 |
+
|
| 434 |
+
logger.info("✅ Stage I Training completed!")
|
| 435 |
+
return best_checkpoint_path
|
| 436 |
+
|
| 437 |
+
|
| 438 |
+
def main():
|
| 439 |
+
parser = argparse.ArgumentParser(description="Stage I: Speech-to-Text Training")
|
| 440 |
+
parser.add_argument(
|
| 441 |
+
"--config",
|
| 442 |
+
type=str,
|
| 443 |
+
default="config/training_config.yaml",
|
| 444 |
+
help="Path to training configuration file"
|
| 445 |
+
)
|
| 446 |
+
parser.add_argument(
|
| 447 |
+
"--resume",
|
| 448 |
+
type=str,
|
| 449 |
+
help="Path to checkpoint to resume from"
|
| 450 |
+
)
|
| 451 |
+
parser.add_argument(
|
| 452 |
+
"--debug",
|
| 453 |
+
action="store_true",
|
| 454 |
+
help="Enable debug mode"
|
| 455 |
+
)
|
| 456 |
+
|
| 457 |
+
args = parser.parse_args()
|
| 458 |
+
|
| 459 |
+
# Load configuration
|
| 460 |
+
with open(args.config, 'r') as f:
|
| 461 |
+
config = yaml.safe_load(f)
|
| 462 |
+
|
| 463 |
+
# Setup logging
|
| 464 |
+
setup_logging(config.get("paths", {}).get("logs_dir", "logs"))
|
| 465 |
+
|
| 466 |
+
# Debug mode
|
| 467 |
+
if args.debug:
|
| 468 |
+
config["debug"]["enabled"] = True
|
| 469 |
+
logger.info("🐛 Debug mode enabled")
|
| 470 |
+
|
| 471 |
+
# Create trainer
|
| 472 |
+
trainer = SpeechToTextTrainer(config)
|
| 473 |
+
|
| 474 |
+
# Resume from checkpoint if specified
|
| 475 |
+
if args.resume:
|
| 476 |
+
trainer.load_checkpoint(args.resume)
|
| 477 |
+
logger.info(f"📂 Resumed from: {args.resume}")
|
| 478 |
+
|
| 479 |
+
# Start training
|
| 480 |
+
try:
|
| 481 |
+
best_checkpoint = trainer.train()
|
| 482 |
+
logger.info(f"🎉 Training completed! Best model: {best_checkpoint}")
|
| 483 |
+
except KeyboardInterrupt:
|
| 484 |
+
logger.info("⛔ Training interrupted by user")
|
| 485 |
+
except Exception as e:
|
| 486 |
+
logger.error(f"❌ Training failed: {e}")
|
| 487 |
+
raise
|
| 488 |
+
|
| 489 |
+
|
| 490 |
+
if __name__ == "__main__":
|
| 491 |
+
main()
|
training/qwen3-0.6b/scripts/utils.py
ADDED
|
@@ -0,0 +1,474 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""
|
| 3 |
+
Utilities for Qwen3-0.6B Speech Training
|
| 4 |
+
=======================================
|
| 5 |
+
Common utilities for training, evaluation, and data processing
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
import os
|
| 9 |
+
import sys
|
| 10 |
+
import torch
|
| 11 |
+
import logging
|
| 12 |
+
import json
|
| 13 |
+
import pickle
|
| 14 |
+
from typing import Dict, Any, Optional, List, Tuple
|
| 15 |
+
from pathlib import Path
|
| 16 |
+
import numpy as np
|
| 17 |
+
from datetime import datetime
|
| 18 |
+
import yaml
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
def setup_logging(log_dir: Optional[str] = None,
|
| 22 |
+
level: int = logging.INFO,
|
| 23 |
+
console: bool = True) -> logging.Logger:
|
| 24 |
+
"""
|
| 25 |
+
Setup logging with file and console output
|
| 26 |
+
"""
|
| 27 |
+
# Create logger
|
| 28 |
+
logger = logging.getLogger("qwen3_training")
|
| 29 |
+
logger.setLevel(level)
|
| 30 |
+
|
| 31 |
+
# Clear existing handlers
|
| 32 |
+
logger.handlers.clear()
|
| 33 |
+
|
| 34 |
+
# Create formatter
|
| 35 |
+
formatter = logging.Formatter(
|
| 36 |
+
'%(asctime)s - %(name)s - %(levelname)s - %(message)s'
|
| 37 |
+
)
|
| 38 |
+
|
| 39 |
+
# Console handler
|
| 40 |
+
if console:
|
| 41 |
+
console_handler = logging.StreamHandler(sys.stdout)
|
| 42 |
+
console_handler.setLevel(level)
|
| 43 |
+
console_handler.setFormatter(formatter)
|
| 44 |
+
logger.addHandler(console_handler)
|
| 45 |
+
|
| 46 |
+
# File handler
|
| 47 |
+
if log_dir:
|
| 48 |
+
log_dir = Path(log_dir)
|
| 49 |
+
log_dir.mkdir(parents=True, exist_ok=True)
|
| 50 |
+
|
| 51 |
+
timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
|
| 52 |
+
log_file = log_dir / f"training_{timestamp}.log"
|
| 53 |
+
|
| 54 |
+
file_handler = logging.FileHandler(log_file)
|
| 55 |
+
file_handler.setLevel(level)
|
| 56 |
+
file_handler.setFormatter(formatter)
|
| 57 |
+
logger.addHandler(file_handler)
|
| 58 |
+
|
| 59 |
+
logger.info(f"📝 Log file: {log_file}")
|
| 60 |
+
|
| 61 |
+
return logger
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
def save_checkpoint(state_dict: Dict[str, Any],
|
| 65 |
+
checkpoint_path: str,
|
| 66 |
+
is_best: bool = False) -> None:
|
| 67 |
+
"""
|
| 68 |
+
Save training checkpoint
|
| 69 |
+
|
| 70 |
+
Args:
|
| 71 |
+
state_dict: Dictionary containing model state and training info
|
| 72 |
+
checkpoint_path: Path to save checkpoint
|
| 73 |
+
is_best: Whether this is the best checkpoint so far
|
| 74 |
+
"""
|
| 75 |
+
checkpoint_path = Path(checkpoint_path)
|
| 76 |
+
checkpoint_path.parent.mkdir(parents=True, exist_ok=True)
|
| 77 |
+
|
| 78 |
+
# Save checkpoint
|
| 79 |
+
torch.save(state_dict, checkpoint_path)
|
| 80 |
+
|
| 81 |
+
# Create best checkpoint copy if needed
|
| 82 |
+
if is_best:
|
| 83 |
+
best_path = checkpoint_path.parent / "best_model.pt"
|
| 84 |
+
torch.save(state_dict, best_path)
|
| 85 |
+
|
| 86 |
+
# Log checkpoint info
|
| 87 |
+
logger = logging.getLogger("qwen3_training")
|
| 88 |
+
logger.info(f"💾 Checkpoint saved: {checkpoint_path}")
|
| 89 |
+
if is_best:
|
| 90 |
+
logger.info(f"⭐ Best model updated")
|
| 91 |
+
|
| 92 |
+
|
| 93 |
+
def load_checkpoint(checkpoint_path: str,
|
| 94 |
+
map_location: Optional[str] = None) -> Dict[str, Any]:
|
| 95 |
+
"""
|
| 96 |
+
Load training checkpoint
|
| 97 |
+
|
| 98 |
+
Args:
|
| 99 |
+
checkpoint_path: Path to checkpoint file
|
| 100 |
+
map_location: Device to map checkpoint to
|
| 101 |
+
|
| 102 |
+
Returns:
|
| 103 |
+
Dictionary containing checkpoint data
|
| 104 |
+
"""
|
| 105 |
+
if not os.path.exists(checkpoint_path):
|
| 106 |
+
raise FileNotFoundError(f"Checkpoint not found: {checkpoint_path}")
|
| 107 |
+
|
| 108 |
+
checkpoint = torch.load(checkpoint_path, map_location=map_location)
|
| 109 |
+
|
| 110 |
+
logger = logging.getLogger("qwen3_training")
|
| 111 |
+
logger.info(f"📂 Checkpoint loaded: {checkpoint_path}")
|
| 112 |
+
|
| 113 |
+
# Log checkpoint info
|
| 114 |
+
if 'epoch' in checkpoint:
|
| 115 |
+
logger.info(f" • Epoch: {checkpoint['epoch']}")
|
| 116 |
+
if 'global_step' in checkpoint:
|
| 117 |
+
logger.info(f" • Step: {checkpoint['global_step']}")
|
| 118 |
+
if 'best_loss' in checkpoint:
|
| 119 |
+
logger.info(f" • Best loss: {checkpoint['best_loss']:.4f}")
|
| 120 |
+
|
| 121 |
+
return checkpoint
|
| 122 |
+
|
| 123 |
+
|
| 124 |
+
def calculate_metrics(predictions: List[str],
|
| 125 |
+
references: List[str]) -> Dict[str, float]:
|
| 126 |
+
"""
|
| 127 |
+
Calculate evaluation metrics
|
| 128 |
+
|
| 129 |
+
Args:
|
| 130 |
+
predictions: List of predicted responses
|
| 131 |
+
references: List of reference responses
|
| 132 |
+
|
| 133 |
+
Returns:
|
| 134 |
+
Dictionary with metric scores
|
| 135 |
+
"""
|
| 136 |
+
from sklearn.metrics.pairwise import cosine_similarity
|
| 137 |
+
from sentence_transformers import SentenceTransformer
|
| 138 |
+
|
| 139 |
+
metrics = {}
|
| 140 |
+
|
| 141 |
+
# Basic metrics
|
| 142 |
+
metrics['num_predictions'] = len(predictions)
|
| 143 |
+
metrics['num_references'] = len(references)
|
| 144 |
+
|
| 145 |
+
# BLEU score (simplified)
|
| 146 |
+
try:
|
| 147 |
+
from nltk.translate.bleu_score import sentence_bleu
|
| 148 |
+
bleu_scores = []
|
| 149 |
+
for pred, ref in zip(predictions, references):
|
| 150 |
+
score = sentence_bleu([ref.split()], pred.split())
|
| 151 |
+
bleu_scores.append(score)
|
| 152 |
+
metrics['bleu'] = np.mean(bleu_scores)
|
| 153 |
+
except ImportError:
|
| 154 |
+
metrics['bleu'] = 0.0
|
| 155 |
+
|
| 156 |
+
# Semantic similarity using sentence transformers
|
| 157 |
+
try:
|
| 158 |
+
model = SentenceTransformer('paraphrase-multilingual-MiniLM-L12-v2')
|
| 159 |
+
pred_embeddings = model.encode(predictions)
|
| 160 |
+
ref_embeddings = model.encode(references)
|
| 161 |
+
|
| 162 |
+
similarities = []
|
| 163 |
+
for pred_emb, ref_emb in zip(pred_embeddings, ref_embeddings):
|
| 164 |
+
sim = cosine_similarity([pred_emb], [ref_emb])[0][0]
|
| 165 |
+
similarities.append(sim)
|
| 166 |
+
|
| 167 |
+
metrics['semantic_similarity'] = np.mean(similarities)
|
| 168 |
+
except:
|
| 169 |
+
metrics['semantic_similarity'] = 0.0
|
| 170 |
+
|
| 171 |
+
# Response length statistics
|
| 172 |
+
pred_lengths = [len(pred.split()) for pred in predictions]
|
| 173 |
+
ref_lengths = [len(ref.split()) for ref in references]
|
| 174 |
+
|
| 175 |
+
metrics['avg_prediction_length'] = np.mean(pred_lengths)
|
| 176 |
+
metrics['avg_reference_length'] = np.mean(ref_lengths)
|
| 177 |
+
metrics['length_ratio'] = metrics['avg_prediction_length'] / metrics['avg_reference_length'] if metrics['avg_reference_length'] > 0 else 0
|
| 178 |
+
|
| 179 |
+
return metrics
|
| 180 |
+
|
| 181 |
+
|
| 182 |
+
class EarlyStopping:
|
| 183 |
+
"""
|
| 184 |
+
Early stopping utility to stop training when validation loss stops improving
|
| 185 |
+
"""
|
| 186 |
+
|
| 187 |
+
def __init__(self, patience: int = 5, min_delta: float = 0.001, mode: str = 'min'):
|
| 188 |
+
self.patience = patience
|
| 189 |
+
self.min_delta = min_delta
|
| 190 |
+
self.mode = mode
|
| 191 |
+
self.best_score = None
|
| 192 |
+
self.counter = 0
|
| 193 |
+
self.early_stop = False
|
| 194 |
+
|
| 195 |
+
if mode not in ['min', 'max']:
|
| 196 |
+
raise ValueError(f"Mode must be 'min' or 'max', got {mode}")
|
| 197 |
+
|
| 198 |
+
def __call__(self, score: float) -> bool:
|
| 199 |
+
"""
|
| 200 |
+
Check if training should stop
|
| 201 |
+
|
| 202 |
+
Args:
|
| 203 |
+
score: Current validation score
|
| 204 |
+
|
| 205 |
+
Returns:
|
| 206 |
+
True if training should stop
|
| 207 |
+
"""
|
| 208 |
+
if self.best_score is None:
|
| 209 |
+
self.best_score = score
|
| 210 |
+
return False
|
| 211 |
+
|
| 212 |
+
if self.mode == 'min':
|
| 213 |
+
improved = score < self.best_score - self.min_delta
|
| 214 |
+
else:
|
| 215 |
+
improved = score > self.best_score + self.min_delta
|
| 216 |
+
|
| 217 |
+
if improved:
|
| 218 |
+
self.best_score = score
|
| 219 |
+
self.counter = 0
|
| 220 |
+
else:
|
| 221 |
+
self.counter += 1
|
| 222 |
+
|
| 223 |
+
if self.counter >= self.patience:
|
| 224 |
+
self.early_stop = True
|
| 225 |
+
|
| 226 |
+
return self.early_stop
|
| 227 |
+
|
| 228 |
+
def reset(self):
|
| 229 |
+
"""Reset early stopping state"""
|
| 230 |
+
self.best_score = None
|
| 231 |
+
self.counter = 0
|
| 232 |
+
self.early_stop = False
|
| 233 |
+
|
| 234 |
+
|
| 235 |
+
def get_model_size(model: torch.nn.Module) -> Dict[str, int]:
|
| 236 |
+
"""
|
| 237 |
+
Calculate model parameter counts
|
| 238 |
+
|
| 239 |
+
Args:
|
| 240 |
+
model: PyTorch model
|
| 241 |
+
|
| 242 |
+
Returns:
|
| 243 |
+
Dictionary with parameter counts
|
| 244 |
+
"""
|
| 245 |
+
total_params = sum(p.numel() for p in model.parameters())
|
| 246 |
+
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
|
| 247 |
+
frozen_params = total_params - trainable_params
|
| 248 |
+
|
| 249 |
+
return {
|
| 250 |
+
'total': total_params,
|
| 251 |
+
'trainable': trainable_params,
|
| 252 |
+
'frozen': frozen_params,
|
| 253 |
+
'trainable_percent': trainable_params / total_params * 100 if total_params > 0 else 0
|
| 254 |
+
}
|
| 255 |
+
|
| 256 |
+
|
| 257 |
+
def format_model_size(size_dict: Dict[str, int]) -> str:
|
| 258 |
+
"""
|
| 259 |
+
Format model size dictionary into readable string
|
| 260 |
+
|
| 261 |
+
Args:
|
| 262 |
+
size_dict: Dictionary from get_model_size()
|
| 263 |
+
|
| 264 |
+
Returns:
|
| 265 |
+
Formatted string
|
| 266 |
+
"""
|
| 267 |
+
total = size_dict['total']
|
| 268 |
+
trainable = size_dict['trainable']
|
| 269 |
+
percent = size_dict['trainable_percent']
|
| 270 |
+
|
| 271 |
+
def format_number(n):
|
| 272 |
+
if n >= 1e9:
|
| 273 |
+
return f"{n/1e9:.1f}B"
|
| 274 |
+
elif n >= 1e6:
|
| 275 |
+
return f"{n/1e6:.1f}M"
|
| 276 |
+
elif n >= 1e3:
|
| 277 |
+
return f"{n/1e3:.1f}K"
|
| 278 |
+
else:
|
| 279 |
+
return str(n)
|
| 280 |
+
|
| 281 |
+
return f"{format_number(trainable)} / {format_number(total)} ({percent:.1f}% trainable)"
|
| 282 |
+
|
| 283 |
+
|
| 284 |
+
def create_run_name(config: Dict[str, Any]) -> str:
|
| 285 |
+
"""
|
| 286 |
+
Create a unique run name based on configuration
|
| 287 |
+
|
| 288 |
+
Args:
|
| 289 |
+
config: Training configuration
|
| 290 |
+
|
| 291 |
+
Returns:
|
| 292 |
+
Run name string
|
| 293 |
+
"""
|
| 294 |
+
timestamp = datetime.now().strftime('%m%d_%H%M')
|
| 295 |
+
|
| 296 |
+
# Extract key parameters
|
| 297 |
+
lora_r = config.get('lora', {}).get('r', 0)
|
| 298 |
+
batch_size = config.get('stage1', {}).get('batch_size', 0)
|
| 299 |
+
lr_lora = config.get('stage1', {}).get('learning_rates', {}).get('lora', 0)
|
| 300 |
+
|
| 301 |
+
# Format learning rate
|
| 302 |
+
lr_str = f"{lr_lora:.0e}".replace('e-0', 'e-').replace('e+0', 'e+')
|
| 303 |
+
|
| 304 |
+
run_name = f"qwen3-lora-r{lora_r}-bs{batch_size}-lr{lr_str}-{timestamp}"
|
| 305 |
+
|
| 306 |
+
return run_name
|
| 307 |
+
|
| 308 |
+
|
| 309 |
+
def save_config(config: Dict[str, Any], save_path: str) -> None:
|
| 310 |
+
"""
|
| 311 |
+
Save configuration to YAML file
|
| 312 |
+
|
| 313 |
+
Args:
|
| 314 |
+
config: Configuration dictionary
|
| 315 |
+
save_path: Path to save config file
|
| 316 |
+
"""
|
| 317 |
+
save_path = Path(save_path)
|
| 318 |
+
save_path.parent.mkdir(parents=True, exist_ok=True)
|
| 319 |
+
|
| 320 |
+
with open(save_path, 'w') as f:
|
| 321 |
+
yaml.dump(config, f, default_flow_style=False, indent=2)
|
| 322 |
+
|
| 323 |
+
|
| 324 |
+
def load_config(config_path: str) -> Dict[str, Any]:
|
| 325 |
+
"""
|
| 326 |
+
Load configuration from YAML file
|
| 327 |
+
|
| 328 |
+
Args:
|
| 329 |
+
config_path: Path to config file
|
| 330 |
+
|
| 331 |
+
Returns:
|
| 332 |
+
Configuration dictionary
|
| 333 |
+
"""
|
| 334 |
+
with open(config_path, 'r') as f:
|
| 335 |
+
config = yaml.safe_load(f)
|
| 336 |
+
|
| 337 |
+
return config
|
| 338 |
+
|
| 339 |
+
|
| 340 |
+
def validate_config(config: Dict[str, Any]) -> List[str]:
|
| 341 |
+
"""
|
| 342 |
+
Validate training configuration
|
| 343 |
+
|
| 344 |
+
Args:
|
| 345 |
+
config: Configuration dictionary
|
| 346 |
+
|
| 347 |
+
Returns:
|
| 348 |
+
List of validation error messages (empty if valid)
|
| 349 |
+
"""
|
| 350 |
+
errors = []
|
| 351 |
+
|
| 352 |
+
# Required sections
|
| 353 |
+
required_sections = ['model', 'whisper', 'speech_projector', 'lora', 'stage1', 'dataset']
|
| 354 |
+
for section in required_sections:
|
| 355 |
+
if section not in config:
|
| 356 |
+
errors.append(f"Missing required section: {section}")
|
| 357 |
+
|
| 358 |
+
# Model configuration
|
| 359 |
+
if 'model' in config:
|
| 360 |
+
if 'name' not in config['model']:
|
| 361 |
+
errors.append("Missing model.name")
|
| 362 |
+
if 'device' not in config['model']:
|
| 363 |
+
errors.append("Missing model.device")
|
| 364 |
+
|
| 365 |
+
# LoRA configuration
|
| 366 |
+
if 'lora' in config:
|
| 367 |
+
lora_config = config['lora']
|
| 368 |
+
if 'r' not in lora_config or lora_config['r'] <= 0:
|
| 369 |
+
errors.append("LoRA rank (r) must be positive")
|
| 370 |
+
if 'target_modules' not in lora_config or not lora_config['target_modules']:
|
| 371 |
+
errors.append("LoRA target_modules cannot be empty")
|
| 372 |
+
|
| 373 |
+
# Dataset configuration
|
| 374 |
+
if 'dataset' in config:
|
| 375 |
+
dataset_config = config['dataset']
|
| 376 |
+
if 'common_voice' not in dataset_config:
|
| 377 |
+
errors.append("Missing dataset.common_voice configuration")
|
| 378 |
+
|
| 379 |
+
cv_config = dataset_config.get('common_voice', {})
|
| 380 |
+
if 'corpus_path' not in cv_config:
|
| 381 |
+
errors.append("Missing dataset.common_voice.corpus_path")
|
| 382 |
+
|
| 383 |
+
return errors
|
| 384 |
+
|
| 385 |
+
|
| 386 |
+
def print_gpu_info():
|
| 387 |
+
"""Print GPU information"""
|
| 388 |
+
logger = logging.getLogger("qwen3_training")
|
| 389 |
+
|
| 390 |
+
if torch.cuda.is_available():
|
| 391 |
+
gpu_count = torch.cuda.device_count()
|
| 392 |
+
logger.info(f"🖥️ GPU Info:")
|
| 393 |
+
|
| 394 |
+
for i in range(gpu_count):
|
| 395 |
+
props = torch.cuda.get_device_properties(i)
|
| 396 |
+
memory_gb = props.total_memory / 1024**3
|
| 397 |
+
logger.info(f" • GPU {i}: {props.name} ({memory_gb:.1f}GB)")
|
| 398 |
+
|
| 399 |
+
# Memory usage
|
| 400 |
+
if i == 0: # Only check first GPU
|
| 401 |
+
allocated_gb = torch.cuda.memory_allocated(i) / 1024**3
|
| 402 |
+
reserved_gb = torch.cuda.memory_reserved(i) / 1024**3
|
| 403 |
+
logger.info(f" Memory: {allocated_gb:.1f}GB allocated, {reserved_gb:.1f}GB reserved")
|
| 404 |
+
else:
|
| 405 |
+
logger.warning("⚠️ No CUDA GPUs available")
|
| 406 |
+
|
| 407 |
+
|
| 408 |
+
class TrainingTimer:
|
| 409 |
+
"""Utility for timing training operations"""
|
| 410 |
+
|
| 411 |
+
def __init__(self):
|
| 412 |
+
self.start_time = None
|
| 413 |
+
self.timers = {}
|
| 414 |
+
|
| 415 |
+
def start(self, name: str = 'default'):
|
| 416 |
+
"""Start timer"""
|
| 417 |
+
import time
|
| 418 |
+
self.timers[name] = time.time()
|
| 419 |
+
|
| 420 |
+
def end(self, name: str = 'default') -> float:
|
| 421 |
+
"""End timer and return elapsed time"""
|
| 422 |
+
import time
|
| 423 |
+
if name not in self.timers:
|
| 424 |
+
return 0.0
|
| 425 |
+
|
| 426 |
+
elapsed = time.time() - self.timers[name]
|
| 427 |
+
del self.timers[name]
|
| 428 |
+
return elapsed
|
| 429 |
+
|
| 430 |
+
def format_time(self, seconds: float) -> str:
|
| 431 |
+
"""Format seconds into readable string"""
|
| 432 |
+
if seconds < 60:
|
| 433 |
+
return f"{seconds:.1f}s"
|
| 434 |
+
elif seconds < 3600:
|
| 435 |
+
minutes = seconds / 60
|
| 436 |
+
return f"{minutes:.1f}m"
|
| 437 |
+
else:
|
| 438 |
+
hours = seconds / 3600
|
| 439 |
+
return f"{hours:.1f}h"
|
| 440 |
+
|
| 441 |
+
|
| 442 |
+
# Example usage and testing
|
| 443 |
+
if __name__ == "__main__":
|
| 444 |
+
# Test utilities
|
| 445 |
+
print("🧪 Testing utilities...")
|
| 446 |
+
|
| 447 |
+
# Test logging setup
|
| 448 |
+
logger = setup_logging(log_dir="test_logs")
|
| 449 |
+
logger.info("Test log message")
|
| 450 |
+
|
| 451 |
+
# Test timer
|
| 452 |
+
timer = TrainingTimer()
|
| 453 |
+
timer.start("test")
|
| 454 |
+
import time
|
| 455 |
+
time.sleep(0.1)
|
| 456 |
+
elapsed = timer.end("test")
|
| 457 |
+
print(f"Timer test: {timer.format_time(elapsed)}")
|
| 458 |
+
|
| 459 |
+
# Test model size calculation (mock model)
|
| 460 |
+
class MockModel(torch.nn.Module):
|
| 461 |
+
def __init__(self):
|
| 462 |
+
super().__init__()
|
| 463 |
+
self.linear1 = torch.nn.Linear(100, 50)
|
| 464 |
+
self.linear2 = torch.nn.Linear(50, 10)
|
| 465 |
+
|
| 466 |
+
# Freeze first layer
|
| 467 |
+
for param in self.linear1.parameters():
|
| 468 |
+
param.requires_grad = False
|
| 469 |
+
|
| 470 |
+
model = MockModel()
|
| 471 |
+
size_info = get_model_size(model)
|
| 472 |
+
print(f"Model size: {format_model_size(size_info)}")
|
| 473 |
+
|
| 474 |
+
print("✅ All utilities working!")
|