Spaces:
Sleeping
Sleeping
Commit
·
0ec3f34
1
Parent(s):
3225a94
refactor: update training code, history and remove large checkpoint to fix storage limit
Browse files- .gitignore +1 -1
- outputs/training_history.json +60 -60
- requirements.txt +17 -9
- src/inference/factory.py +10 -9
- src/models/factory.py +38 -19
- start_training.bat +4 -0
.gitignore
CHANGED
|
@@ -30,7 +30,7 @@ data/cache/
|
|
| 30 |
|
| 31 |
# Models
|
| 32 |
checkpoints/*.pt
|
| 33 |
-
!checkpoints/best.pt
|
| 34 |
*.pth
|
| 35 |
*.ckpt
|
| 36 |
!artifacts/*.json
|
|
|
|
| 30 |
|
| 31 |
# Models
|
| 32 |
checkpoints/*.pt
|
| 33 |
+
# !checkpoints/best.pt
|
| 34 |
*.pth
|
| 35 |
*.ckpt
|
| 36 |
!artifacts/*.json
|
outputs/training_history.json
CHANGED
|
@@ -1,92 +1,92 @@
|
|
| 1 |
{
|
| 2 |
"train_epoch_1": {
|
| 3 |
-
"summarization_loss":
|
| 4 |
-
"summarization_rouge_like": 0.
|
| 5 |
-
"emotion_loss": 0.
|
| 6 |
-
"emotion_f1": 0.
|
| 7 |
-
"topic_loss": 0.
|
| 8 |
-
"topic_accuracy": 0.
|
| 9 |
"epoch": 1.0
|
| 10 |
},
|
| 11 |
"val_epoch_1": {
|
| 12 |
-
"summarization_loss":
|
| 13 |
-
"summarization_rouge_like": 0.
|
| 14 |
-
"emotion_loss": 0.
|
| 15 |
-
"emotion_f1": 0.
|
| 16 |
-
"topic_loss": 0.
|
| 17 |
-
"topic_accuracy": 0.
|
| 18 |
"epoch": 1.0
|
| 19 |
},
|
| 20 |
"train_epoch_2": {
|
| 21 |
-
"summarization_loss":
|
| 22 |
-
"summarization_rouge_like": 0.
|
| 23 |
-
"emotion_loss": 0.
|
| 24 |
-
"emotion_f1": 0.
|
| 25 |
-
"topic_loss": 0.
|
| 26 |
-
"topic_accuracy": 0.
|
| 27 |
"epoch": 2.0
|
| 28 |
},
|
| 29 |
"val_epoch_2": {
|
| 30 |
-
"summarization_loss":
|
| 31 |
-
"summarization_rouge_like": 0.
|
| 32 |
-
"emotion_loss": 0.
|
| 33 |
-
"emotion_f1": 0.
|
| 34 |
-
"topic_loss": 0.
|
| 35 |
-
"topic_accuracy": 0.
|
| 36 |
"epoch": 2.0
|
| 37 |
},
|
| 38 |
"train_epoch_3": {
|
| 39 |
-
"summarization_loss":
|
| 40 |
-
"summarization_rouge_like": 0.
|
| 41 |
-
"emotion_loss": 0.
|
| 42 |
-
"emotion_f1": 0.
|
| 43 |
-
"topic_loss": 0.
|
| 44 |
-
"topic_accuracy": 0.
|
| 45 |
"epoch": 3.0
|
| 46 |
},
|
| 47 |
"val_epoch_3": {
|
| 48 |
-
"summarization_loss":
|
| 49 |
-
"summarization_rouge_like": 0.
|
| 50 |
-
"emotion_loss": 0.
|
| 51 |
-
"emotion_f1": 0.
|
| 52 |
-
"topic_loss": 0.
|
| 53 |
-
"topic_accuracy": 0.
|
| 54 |
"epoch": 3.0
|
| 55 |
},
|
| 56 |
"train_epoch_4": {
|
| 57 |
-
"summarization_loss":
|
| 58 |
-
"summarization_rouge_like": 0.
|
| 59 |
-
"emotion_loss": 0.
|
| 60 |
-
"emotion_f1": 0.
|
| 61 |
-
"topic_loss": 0.
|
| 62 |
-
"topic_accuracy": 0.
|
| 63 |
"epoch": 4.0
|
| 64 |
},
|
| 65 |
"val_epoch_4": {
|
| 66 |
-
"summarization_loss":
|
| 67 |
-
"summarization_rouge_like": 0.
|
| 68 |
-
"emotion_loss": 0.
|
| 69 |
-
"emotion_f1": 0.
|
| 70 |
-
"topic_loss": 0.
|
| 71 |
-
"topic_accuracy": 0.
|
| 72 |
"epoch": 4.0
|
| 73 |
},
|
| 74 |
"train_epoch_5": {
|
| 75 |
-
"summarization_loss":
|
| 76 |
-
"summarization_rouge_like": 0.
|
| 77 |
-
"emotion_loss": 0.
|
| 78 |
-
"emotion_f1": 0.
|
| 79 |
-
"topic_loss": 0.
|
| 80 |
-
"topic_accuracy": 0.
|
| 81 |
"epoch": 5.0
|
| 82 |
},
|
| 83 |
"val_epoch_5": {
|
| 84 |
-
"summarization_loss":
|
| 85 |
-
"summarization_rouge_like": 0.
|
| 86 |
-
"emotion_loss": 0.
|
| 87 |
-
"emotion_f1": 0.
|
| 88 |
-
"topic_loss": 0.
|
| 89 |
-
"topic_accuracy": 0.
|
| 90 |
"epoch": 5.0
|
| 91 |
}
|
| 92 |
}
|
|
|
|
| 1 |
{
|
| 2 |
"train_epoch_1": {
|
| 3 |
+
"summarization_loss": 5.023585737518827,
|
| 4 |
+
"summarization_rouge_like": 0.19371884805954312,
|
| 5 |
+
"emotion_loss": 0.0821188951971249,
|
| 6 |
+
"emotion_f1": 0.865718169566,
|
| 7 |
+
"topic_loss": 0.24917707448061954,
|
| 8 |
+
"topic_accuracy": 0.9192776539426024,
|
| 9 |
"epoch": 1.0
|
| 10 |
},
|
| 11 |
"val_epoch_1": {
|
| 12 |
+
"summarization_loss": 3.7266472615858954,
|
| 13 |
+
"summarization_rouge_like": 0.2827026719016518,
|
| 14 |
+
"emotion_loss": 0.14450823713558134,
|
| 15 |
+
"emotion_f1": 0.9086874146293125,
|
| 16 |
+
"topic_loss": 0.21787223087735602,
|
| 17 |
+
"topic_accuracy": 0.9326002393776182,
|
| 18 |
"epoch": 1.0
|
| 19 |
},
|
| 20 |
"train_epoch_2": {
|
| 21 |
+
"summarization_loss": 3.398382334982861,
|
| 22 |
+
"summarization_rouge_like": 0.31421210196164595,
|
| 23 |
+
"emotion_loss": 0.008744604070504772,
|
| 24 |
+
"emotion_f1": 0.9922616565848632,
|
| 25 |
+
"topic_loss": 0.12368396144345378,
|
| 26 |
+
"topic_accuracy": 0.9631060183895236,
|
| 27 |
"epoch": 2.0
|
| 28 |
},
|
| 29 |
"val_epoch_2": {
|
| 30 |
+
"summarization_loss": 2.728874285017067,
|
| 31 |
+
"summarization_rouge_like": 0.3867885960963845,
|
| 32 |
+
"emotion_loss": 0.20949344621063382,
|
| 33 |
+
"emotion_f1": 0.9095850804121747,
|
| 34 |
+
"topic_loss": 0.2887416907434674,
|
| 35 |
+
"topic_accuracy": 0.9329742669060442,
|
| 36 |
"epoch": 2.0
|
| 37 |
},
|
| 38 |
"train_epoch_3": {
|
| 39 |
+
"summarization_loss": 2.699047506134568,
|
| 40 |
+
"summarization_rouge_like": 0.38349341261349945,
|
| 41 |
+
"emotion_loss": 0.005096756787117961,
|
| 42 |
+
"emotion_f1": 0.9953213525834805,
|
| 43 |
+
"topic_loss": 0.07009015341349616,
|
| 44 |
+
"topic_accuracy": 0.9802800222903316,
|
| 45 |
"epoch": 3.0
|
| 46 |
},
|
| 47 |
"val_epoch_3": {
|
| 48 |
+
"summarization_loss": 2.354555403451446,
|
| 49 |
+
"summarization_rouge_like": 0.4275408038759501,
|
| 50 |
+
"emotion_loss": 0.20089952317384335,
|
| 51 |
+
"emotion_f1": 0.9075279304326329,
|
| 52 |
+
"topic_loss": 0.4845805834182202,
|
| 53 |
+
"topic_accuracy": 0.9298324356672651,
|
| 54 |
"epoch": 3.0
|
| 55 |
},
|
| 56 |
"train_epoch_4": {
|
| 57 |
+
"summarization_loss": 2.3750830047009015,
|
| 58 |
+
"summarization_rouge_like": 0.4200744394095619,
|
| 59 |
+
"emotion_loss": 0.0037049090056492364,
|
| 60 |
+
"emotion_f1": 0.9962315410599798,
|
| 61 |
+
"topic_loss": 0.042221361385891144,
|
| 62 |
+
"topic_accuracy": 0.9888652828085818,
|
| 63 |
"epoch": 4.0
|
| 64 |
},
|
| 65 |
"val_epoch_4": {
|
| 66 |
+
"summarization_loss": 2.198225014299636,
|
| 67 |
+
"summarization_rouge_like": 0.444635960654823,
|
| 68 |
+
"emotion_loss": 0.20359252842952202,
|
| 69 |
+
"emotion_f1": 0.9163175773506461,
|
| 70 |
+
"topic_loss": 0.5501026207833392,
|
| 71 |
+
"topic_accuracy": 0.9272890484739676,
|
| 72 |
"epoch": 4.0
|
| 73 |
},
|
| 74 |
"train_epoch_5": {
|
| 75 |
+
"summarization_loss": 2.186419085976007,
|
| 76 |
+
"summarization_rouge_like": 0.4416556068282783,
|
| 77 |
+
"emotion_loss": 0.0030099891204739266,
|
| 78 |
+
"emotion_f1": 0.9964672148443591,
|
| 79 |
+
"topic_loss": 0.03006078401232904,
|
| 80 |
+
"topic_accuracy": 0.9925606018389523,
|
| 81 |
"epoch": 5.0
|
| 82 |
},
|
| 83 |
"val_epoch_5": {
|
| 84 |
+
"summarization_loss": 2.114973693461849,
|
| 85 |
+
"summarization_rouge_like": 0.4553148986859889,
|
| 86 |
+
"emotion_loss": 0.2197709748711572,
|
| 87 |
+
"emotion_f1": 0.9121534032496345,
|
| 88 |
+
"topic_loss": 0.6607796598369469,
|
| 89 |
+
"topic_accuracy": 0.931178934769599,
|
| 90 |
"epoch": 5.0
|
| 91 |
}
|
| 92 |
}
|
requirements.txt
CHANGED
|
@@ -1,15 +1,23 @@
|
|
| 1 |
# requirements.txt
|
| 2 |
torch>=2.0.0
|
| 3 |
-
transformers>=4.
|
| 4 |
-
|
|
|
|
| 5 |
numpy>=1.24.0
|
| 6 |
pandas>=2.0.0
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 7 |
streamlit>=1.25.0
|
| 8 |
plotly>=5.18.0
|
| 9 |
-
|
| 10 |
-
|
| 11 |
-
|
| 12 |
-
seaborn
|
| 13 |
-
pytest
|
| 14 |
-
matplotlib
|
| 15 |
-
rouge-score>=0.1.2
|
|
|
|
| 1 |
# requirements.txt
|
| 2 |
torch>=2.0.0
|
| 3 |
+
transformers>=4.30.0
|
| 4 |
+
datasets>=2.14.0
|
| 5 |
+
tokenizers>=0.13.0
|
| 6 |
numpy>=1.24.0
|
| 7 |
pandas>=2.0.0
|
| 8 |
+
scikit-learn>=1.3.0
|
| 9 |
+
matplotlib>=3.7.0
|
| 10 |
+
seaborn>=0.12.0
|
| 11 |
+
nltk>=3.8.0
|
| 12 |
+
tqdm>=4.65.0
|
| 13 |
+
pyyaml>=6.0
|
| 14 |
+
omegaconf>=2.3.0
|
| 15 |
+
tensorboard>=2.13.0
|
| 16 |
+
gradio>=3.35.0
|
| 17 |
+
requests>=2.31.0
|
| 18 |
+
kaggle>=1.5.12
|
| 19 |
streamlit>=1.25.0
|
| 20 |
plotly>=5.18.0
|
| 21 |
+
faiss-cpu==1.9.0; platform_system != "Windows"
|
| 22 |
+
faiss-cpu==1.9.0; platform_system == "Windows"
|
| 23 |
+
huggingface_hub>=0.19.0
|
|
|
|
|
|
|
|
|
|
|
|
src/inference/factory.py
CHANGED
|
@@ -6,6 +6,7 @@ from typing import Tuple
|
|
| 6 |
|
| 7 |
import torch
|
| 8 |
|
|
|
|
| 9 |
from ..data.tokenization import Tokenizer, TokenizerConfig
|
| 10 |
from ..models.factory import ModelConfig, build_multitask_model, load_model_config
|
| 11 |
from ..utils.io import load_state
|
|
@@ -45,24 +46,23 @@ def create_inference_pipeline(
|
|
| 45 |
)
|
| 46 |
|
| 47 |
tokenizer = Tokenizer(resolved_tokenizer_config)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 48 |
model_config = load_model_config(model_config_path)
|
| 49 |
model = build_multitask_model(
|
| 50 |
tokenizer,
|
| 51 |
num_emotions=labels.emotion_size,
|
| 52 |
num_topics=labels.topic_size,
|
| 53 |
config=model_config,
|
|
|
|
| 54 |
)
|
|
|
|
|
|
|
| 55 |
load_state(model, str(checkpoint))
|
| 56 |
|
| 57 |
-
# Tie weights manually to ensure decoder output projection matches embeddings
|
| 58 |
-
# This fixes issues where the output projection might be untrained or mismatched
|
| 59 |
-
decoder = getattr(model, "decoder", None)
|
| 60 |
-
output_projection = getattr(decoder, "output_projection", None) if decoder is not None else None
|
| 61 |
-
embedding = getattr(decoder, "embedding", None) if decoder is not None else None
|
| 62 |
-
|
| 63 |
-
if output_projection is not None and embedding is not None:
|
| 64 |
-
output_projection.weight = embedding.weight
|
| 65 |
-
|
| 66 |
if isinstance(device, torch.device):
|
| 67 |
device_str = str(device)
|
| 68 |
else:
|
|
@@ -80,5 +80,6 @@ def create_inference_pipeline(
|
|
| 80 |
emotion_labels=labels.emotion,
|
| 81 |
topic_labels=labels.topic,
|
| 82 |
device=device,
|
|
|
|
| 83 |
)
|
| 84 |
return pipeline, labels
|
|
|
|
| 6 |
|
| 7 |
import torch
|
| 8 |
|
| 9 |
+
from ..data.preprocessing import TextPreprocessor
|
| 10 |
from ..data.tokenization import Tokenizer, TokenizerConfig
|
| 11 |
from ..models.factory import ModelConfig, build_multitask_model, load_model_config
|
| 12 |
from ..utils.io import load_state
|
|
|
|
| 46 |
)
|
| 47 |
|
| 48 |
tokenizer = Tokenizer(resolved_tokenizer_config)
|
| 49 |
+
|
| 50 |
+
# Default to base config if not specified (checkpoint was trained with base config)
|
| 51 |
+
if model_config_path is None:
|
| 52 |
+
model_config_path = Path(__file__).resolve().parent.parent.parent / "configs" / "model" / "base.yaml"
|
| 53 |
+
|
| 54 |
model_config = load_model_config(model_config_path)
|
| 55 |
model = build_multitask_model(
|
| 56 |
tokenizer,
|
| 57 |
num_emotions=labels.emotion_size,
|
| 58 |
num_topics=labels.topic_size,
|
| 59 |
config=model_config,
|
| 60 |
+
load_pretrained=False,
|
| 61 |
)
|
| 62 |
+
|
| 63 |
+
# Load checkpoint - weights will load separately since factory doesn't tie them
|
| 64 |
load_state(model, str(checkpoint))
|
| 65 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 66 |
if isinstance(device, torch.device):
|
| 67 |
device_str = str(device)
|
| 68 |
else:
|
|
|
|
| 80 |
emotion_labels=labels.emotion,
|
| 81 |
topic_labels=labels.topic,
|
| 82 |
device=device,
|
| 83 |
+
preprocessor=TextPreprocessor(tokenizer=tokenizer, lowercase=tokenizer.config.lower),
|
| 84 |
)
|
| 85 |
return pipeline, labels
|
src/models/factory.py
CHANGED
|
@@ -69,7 +69,8 @@ def _load_pretrained_weights(encoder: TransformerEncoder, decoder: TransformerDe
|
|
| 69 |
# Load encoder weights
|
| 70 |
print("Transferring encoder weights...")
|
| 71 |
encoder.embedding.weight.data.copy_(bart.encoder.embed_tokens.weight.data)
|
| 72 |
-
|
|
|
|
| 73 |
|
| 74 |
for i, (custom_layer, bart_layer) in enumerate(zip(encoder.layers, bart.encoder.layers)):
|
| 75 |
# Self-attention
|
|
@@ -88,19 +89,22 @@ def _load_pretrained_weights(encoder: TransformerEncoder, decoder: TransformerDe
|
|
| 88 |
custom_layer.norm2.weight.data.copy_(bart_layer.final_layer_norm.weight.data)
|
| 89 |
custom_layer.norm2.bias.data.copy_(bart_layer.final_layer_norm.bias.data)
|
| 90 |
|
| 91 |
-
# FFN
|
| 92 |
-
custom_layer.ffn.
|
| 93 |
-
custom_layer.ffn.
|
| 94 |
-
custom_layer.ffn.
|
| 95 |
-
custom_layer.ffn.
|
| 96 |
|
| 97 |
-
|
| 98 |
-
|
|
|
|
|
|
|
|
|
|
| 99 |
|
| 100 |
# Load decoder weights
|
| 101 |
print("Transferring decoder weights...")
|
| 102 |
decoder.embedding.weight.data.copy_(bart.decoder.embed_tokens.weight.data)
|
| 103 |
-
|
| 104 |
|
| 105 |
for i, (custom_layer, bart_layer) in enumerate(zip(decoder.layers, bart.decoder.layers)):
|
| 106 |
# Self-attention
|
|
@@ -131,14 +135,16 @@ def _load_pretrained_weights(encoder: TransformerEncoder, decoder: TransformerDe
|
|
| 131 |
custom_layer.norm3.weight.data.copy_(bart_layer.final_layer_norm.weight.data)
|
| 132 |
custom_layer.norm3.bias.data.copy_(bart_layer.final_layer_norm.bias.data)
|
| 133 |
|
| 134 |
-
# FFN
|
| 135 |
-
custom_layer.ffn.
|
| 136 |
-
custom_layer.ffn.
|
| 137 |
-
custom_layer.ffn.
|
| 138 |
-
custom_layer.ffn.
|
| 139 |
|
| 140 |
-
|
| 141 |
-
|
|
|
|
|
|
|
| 142 |
|
| 143 |
print("Pretrained weights loaded successfully!")
|
| 144 |
|
|
@@ -149,8 +155,17 @@ def build_multitask_model(
|
|
| 149 |
num_emotions: int,
|
| 150 |
num_topics: int,
|
| 151 |
config: ModelConfig | None = None,
|
|
|
|
| 152 |
) -> MultiTaskModel:
|
| 153 |
-
"""Construct the multitask transformer with heads for the three tasks.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 154 |
|
| 155 |
cfg = config or ModelConfig()
|
| 156 |
if not isinstance(num_emotions, int) or num_emotions <= 0:
|
|
@@ -179,10 +194,14 @@ def build_multitask_model(
|
|
| 179 |
pad_token_id=tokenizer.pad_token_id,
|
| 180 |
)
|
| 181 |
|
| 182 |
-
# Load pretrained weights if requested
|
| 183 |
-
|
|
|
|
| 184 |
_load_pretrained_weights(encoder, decoder, cfg.pretrained_model_name)
|
| 185 |
|
|
|
|
|
|
|
|
|
|
| 186 |
|
| 187 |
model = MultiTaskModel(encoder=encoder, decoder=decoder, decoder_outputs_logits=True)
|
| 188 |
model.add_head(
|
|
|
|
| 69 |
# Load encoder weights
|
| 70 |
print("Transferring encoder weights...")
|
| 71 |
encoder.embedding.weight.data.copy_(bart.encoder.embed_tokens.weight.data)
|
| 72 |
+
# Skip positional encoding - BART uses learned positions, we use sinusoidal
|
| 73 |
+
# our implementation works fine with sinusoidal encodings
|
| 74 |
|
| 75 |
for i, (custom_layer, bart_layer) in enumerate(zip(encoder.layers, bart.encoder.layers)):
|
| 76 |
# Self-attention
|
|
|
|
| 89 |
custom_layer.norm2.weight.data.copy_(bart_layer.final_layer_norm.weight.data)
|
| 90 |
custom_layer.norm2.bias.data.copy_(bart_layer.final_layer_norm.bias.data)
|
| 91 |
|
| 92 |
+
# FFN - use linear1/linear2 (not fc1/fc2)
|
| 93 |
+
custom_layer.ffn.linear1.weight.data.copy_(bart_layer.fc1.weight.data)
|
| 94 |
+
custom_layer.ffn.linear1.bias.data.copy_(bart_layer.fc1.bias.data)
|
| 95 |
+
custom_layer.ffn.linear2.weight.data.copy_(bart_layer.fc2.weight.data)
|
| 96 |
+
custom_layer.ffn.linear2.bias.data.copy_(bart_layer.fc2.bias.data)
|
| 97 |
|
| 98 |
+
# BART has layernorm_embedding at the input, we have final_norm at output
|
| 99 |
+
# Copy it to final_norm - not a perfect match but close enough for transfer learning
|
| 100 |
+
if hasattr(bart.encoder, 'layernorm_embedding'):
|
| 101 |
+
encoder.final_norm.weight.data.copy_(bart.encoder.layernorm_embedding.weight.data)
|
| 102 |
+
encoder.final_norm.bias.data.copy_(bart.encoder.layernorm_embedding.bias.data)
|
| 103 |
|
| 104 |
# Load decoder weights
|
| 105 |
print("Transferring decoder weights...")
|
| 106 |
decoder.embedding.weight.data.copy_(bart.decoder.embed_tokens.weight.data)
|
| 107 |
+
# Skip positional encoding - BART uses learned positions, we use sinusoidal
|
| 108 |
|
| 109 |
for i, (custom_layer, bart_layer) in enumerate(zip(decoder.layers, bart.decoder.layers)):
|
| 110 |
# Self-attention
|
|
|
|
| 135 |
custom_layer.norm3.weight.data.copy_(bart_layer.final_layer_norm.weight.data)
|
| 136 |
custom_layer.norm3.bias.data.copy_(bart_layer.final_layer_norm.bias.data)
|
| 137 |
|
| 138 |
+
# FFN - use linear1/linear2 (not fc1/fc2)
|
| 139 |
+
custom_layer.ffn.linear1.weight.data.copy_(bart_layer.fc1.weight.data)
|
| 140 |
+
custom_layer.ffn.linear1.bias.data.copy_(bart_layer.fc1.bias.data)
|
| 141 |
+
custom_layer.ffn.linear2.weight.data.copy_(bart_layer.fc2.weight.data)
|
| 142 |
+
custom_layer.ffn.linear2.bias.data.copy_(bart_layer.fc2.bias.data)
|
| 143 |
|
| 144 |
+
# BART has layernorm_embedding at the input, we have final_norm at output
|
| 145 |
+
if hasattr(bart.decoder, 'layernorm_embedding'):
|
| 146 |
+
decoder.final_norm.weight.data.copy_(bart.decoder.layernorm_embedding.weight.data)
|
| 147 |
+
decoder.final_norm.bias.data.copy_(bart.decoder.layernorm_embedding.bias.data)
|
| 148 |
|
| 149 |
print("Pretrained weights loaded successfully!")
|
| 150 |
|
|
|
|
| 155 |
num_emotions: int,
|
| 156 |
num_topics: int,
|
| 157 |
config: ModelConfig | None = None,
|
| 158 |
+
load_pretrained: bool | None = None,
|
| 159 |
) -> MultiTaskModel:
|
| 160 |
+
"""Construct the multitask transformer with heads for the three tasks.
|
| 161 |
+
|
| 162 |
+
Args:
|
| 163 |
+
tokenizer: Tokenizer for vocabulary size and pad token
|
| 164 |
+
num_emotions: Number of emotion classes
|
| 165 |
+
num_topics: Number of topic classes
|
| 166 |
+
config: Model architecture configuration
|
| 167 |
+
load_pretrained: Override config.use_pretrained (for inference to skip loading)
|
| 168 |
+
"""
|
| 169 |
|
| 170 |
cfg = config or ModelConfig()
|
| 171 |
if not isinstance(num_emotions, int) or num_emotions <= 0:
|
|
|
|
| 194 |
pad_token_id=tokenizer.pad_token_id,
|
| 195 |
)
|
| 196 |
|
| 197 |
+
# Load pretrained weights if requested (but allow override for inference)
|
| 198 |
+
should_load = cfg.use_pretrained if load_pretrained is None else load_pretrained
|
| 199 |
+
if should_load:
|
| 200 |
_load_pretrained_weights(encoder, decoder, cfg.pretrained_model_name)
|
| 201 |
|
| 202 |
+
# NOTE: Weight tying disabled because the current checkpoint was trained without it
|
| 203 |
+
# For NEW training runs, uncomment this line to enable proper weight tying:
|
| 204 |
+
# decoder.output_projection.weight = decoder.embedding.weight
|
| 205 |
|
| 206 |
model = MultiTaskModel(encoder=encoder, decoder=decoder, decoder_outputs_logits=True)
|
| 207 |
model.add_head(
|
start_training.bat
ADDED
|
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
@echo off
|
| 2 |
+
cd /d C:\Users\olive\OneDrive\Desktop\LexiMind\LexiMind
|
| 3 |
+
call C:\Users\olive\OneDrive\Desktop\LexiMind\.venv\Scripts\activate.bat
|
| 4 |
+
python scripts\train.py --training-config configs\training\default.yaml --model-config configs\model\base.yaml --data-config configs\data\datasets.yaml --device cuda > logs\training_live.log 2>&1
|