# gapura-oneclick/training/train_nlp.py
"""
Train NLP Models for Gapura Irregularity Reports
Fine-tunes Indonesian BERT for classification, NER, and sentiment analysis
"""
import os
import re
import sys
import logging
import pickle
import json
from datetime import datetime
from typing import Dict, List
from dotenv import load_dotenv
load_dotenv()
import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelEncoder
from sklearn.metrics import classification_report, accuracy_score
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from transformers import (
    AutoTokenizer,
    AutoModelForSequenceClassification,
    Trainer,
    TrainingArguments,
    EarlyStoppingCallback,
)
from datasets import Dataset as HFDataset
from sklearn.utils.class_weight import compute_class_weight
import warnings
warnings.filterwarnings("ignore")
logging.basicConfig(
level=logging.INFO,
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
handlers=[logging.FileHandler("training_nlp.log"), logging.StreamHandler()],
)
logger = logging.getLogger(__name__)
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from data.sheets_service import GoogleSheetsService
class SeverityClassifier:
"""Classify severity level from report text"""
def __init__(self, model_name="indobenchmark/indobert-base-p1"):
self.model_name = model_name
self.local_model_path = os.path.join(
os.path.dirname(os.path.dirname(os.path.abspath(__file__))),
"models",
"nlp",
"indobert-base",
)
self.tokenizer = None
self.model = None
self.label_encoder = LabelEncoder()
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    def _get_model_path(self):
        """Return the local model directory if present, otherwise the hub name"""
        config_path = os.path.join(self.local_model_path, "config.json")
        if os.path.exists(config_path):
            logger.info(f"Using local model: {self.local_model_path}")
            return self.local_model_path
        logger.info(f"Local model not found; falling back to hub: {self.model_name}")
        return self.model_name
def prepare_data(self, reports: List[Dict]) -> pd.DataFrame:
"""Prepare data for training"""
data = []
for report in reports:
# Combine text fields
texts = []
for field in ["Report", "Root_Caused", "Action_Taken"]:
text = report.get(field, "")
if text and text != "#N/A":
texts.append(text)
combined_text = " ".join(texts)
if not combined_text or len(combined_text) < 10:
continue
# Determine severity using bilingual weighted keywords with intensifiers and negations
tl = combined_text.lower()
critical = {
"emergency": 1.2,
"darurat": 1.2,
"critical": 1.2,
"kritis": 1.2,
"severe": 1.1,
"parah": 1.1,
"injury": 1.2,
"cedera": 1.2,
"fire": 1.3,
"kebakaran": 1.3,
"explosion": 1.3,
"ledakan": 1.3,
"evacuate": 1.2,
"evakuasi": 1.2,
"safety": 1.1,
"keselamatan": 1.1,
"accident": 1.2,
"kecelakaan": 1.2,
}
high = {
"damage": 1.0,
"rusak": 1.0,
"broken": 1.0,
"pecah": 1.0,
"patah": 1.0,
"torn": 0.9,
"robek": 0.9,
"spillage": 0.9,
"bocor": 0.9,
"lost": 0.8,
"hilang": 0.8,
"stolen": 0.9,
"dicuri": 0.9,
"unsafe": 1.0,
"berbahaya": 1.0,
}
medium = {
"delay": 0.6,
"terlambat": 0.6,
"telat": 0.6,
"late": 0.6,
"misload": 0.6,
"salah muat": 0.6,
"wrong": 0.5,
"incorrect": 0.5,
"keliru": 0.5,
"missing": 0.5,
"tidak ada": 0.5,
"error": 0.5,
"kesalahan": 0.5,
"fail": 0.5,
"gagal": 0.5,
"complaint": 0.5,
"keluhan": 0.5,
"complain": 0.5,
"komplain": 0.5,
}
low = {
"minor": 0.3,
"kecil": 0.3,
"ringan": 0.3,
"normal": 0.2,
"rutin": 0.2,
}
intensifiers = {
"very",
"sangat",
"extremely",
"urgent",
"mendesak",
"segera",
"immediately",
"secepatnya",
}
deintensifiers = {"slight", "sedikit", "minor", "low", "ringan"}
negations = {"no", "not", "tidak", "bukan", "tanpa", "false alarm"}
            score = 0.0
            # Match on word boundaries so short tokens like "no" do not fire
            # inside unrelated words such as "normal"
            for bucket in (critical, high, medium, low):
                for kw, w in bucket.items():
                    if re.search(rf"\b{re.escape(kw)}\b", tl):
                        score += w
            if score > 0 and any(
                re.search(rf"\b{re.escape(t)}\b", tl) for t in intensifiers
            ):
                score *= 1.2
            if any(re.search(rf"\b{re.escape(t)}\b", tl) for t in deintensifiers):
                score *= 0.85
            if any(re.search(rf"\b{re.escape(n)}\b", tl) for n in negations):
                score *= 0.7
            excl_bonus = min(0.1, tl.count("!") * 0.03)
            score += excl_bonus
norm = min(1.0, score / 3.0)
if norm >= 0.75:
severity = "Critical"
elif norm >= 0.5:
severity = "High"
elif norm >= 0.25:
severity = "Medium"
else:
severity = "Low"
data.append(
{
"text": combined_text,
"severity": severity,
"category": report.get("Irregularity_Complain_Category", "Unknown"),
"area": report.get("Area", "Unknown"),
}
)
return pd.DataFrame(data)
def train(self, reports: List[Dict]):
"""Train severity classifier"""
logger.info("=" * 60)
logger.info("Training Severity Classifier")
logger.info("=" * 60)
# Prepare data
df = self.prepare_data(reports)
if len(df) < 20:
logger.error("Not enough data to train NLP model")
return None
logger.info(f"Prepared {len(df)} samples")
logger.info(f"Severity distribution:\n{df['severity'].value_counts()}")
# Encode labels
df["label"] = self.label_encoder.fit_transform(df["severity"])
# Split data
train_df, test_df = train_test_split(
df, test_size=0.2, random_state=42, stratify=df["severity"]
)
logger.info(f"\nTraining samples: {len(train_df)}")
logger.info(f"Test samples: {len(test_df)}")
# Load tokenizer and model
model_path = self._get_model_path()
logger.info(f"\nLoading pre-trained model: {model_path}")
self.tokenizer = AutoTokenizer.from_pretrained(model_path)
num_labels = len(self.label_encoder.classes_)
self.model = AutoModelForSequenceClassification.from_pretrained(
model_path, num_labels=num_labels
)
self.model.to(self.device)
# Tokenize data
def tokenize_function(examples):
return self.tokenizer(
examples["text"], padding="max_length", truncation=True, max_length=512
)
train_dataset = HFDataset.from_pandas(train_df)
test_dataset = HFDataset.from_pandas(test_df)
train_dataset = train_dataset.map(tokenize_function, batched=True)
test_dataset = test_dataset.map(tokenize_function, batched=True)
# Set format for PyTorch
train_dataset.set_format(
"torch", columns=["input_ids", "attention_mask", "label"]
)
test_dataset.set_format(
"torch", columns=["input_ids", "attention_mask", "label"]
)
# Compute class weights to handle imbalance
classes = np.unique(df["label"])
class_weights = compute_class_weight(
class_weight="balanced", classes=classes, y=df["label"].values
)
class_weights = torch.tensor(class_weights, dtype=torch.float)
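        # sklearn's "balanced" heuristic weights each class as
        # n_samples / (n_classes * count); e.g. hypothetical counts
        # {Low: 60, Medium: 25, High: 10, Critical: 5} over 100 samples
        # yield weights 0.42, 1.0, 2.5, and 5.0 respectively.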
# Training arguments
training_args = TrainingArguments(
output_dir="./results_severity",
num_train_epochs=5,
per_device_train_batch_size=8,
per_device_eval_batch_size=16,
warmup_steps=100,
weight_decay=0.01,
logging_dir="./logs",
logging_steps=10,
evaluation_strategy="epoch",
save_strategy="epoch",
load_best_model_at_end=True,
metric_for_best_model="eval_loss",
greater_is_better=False,
report_to="none", # Disable wandb
)
# Weighted loss
class WeightedTrainer(Trainer):
def __init__(self, class_weights: torch.Tensor, *args, **kwargs):
super().__init__(*args, **kwargs)
self.class_weights = class_weights.to(self.model.device)
            # **kwargs absorbs extra arguments (e.g. num_items_in_batch)
            # passed by newer versions of transformers' Trainer
            def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
labels = inputs.get("labels")
outputs = model(**inputs)
logits = outputs.get("logits")
loss_fct = nn.CrossEntropyLoss(weight=self.class_weights)
loss = loss_fct(logits.view(-1, logits.size(-1)), labels.view(-1))
return (loss, outputs) if return_outputs else loss
# Initialize trainer
def compute_metrics(eval_pred):
predictions, labels = eval_pred
predictions = np.argmax(predictions, axis=1)
return {
"accuracy": accuracy_score(labels, predictions),
}
trainer = WeightedTrainer(
class_weights=class_weights,
model=self.model,
args=training_args,
train_dataset=train_dataset,
eval_dataset=test_dataset,
compute_metrics=compute_metrics,
callbacks=[EarlyStoppingCallback(early_stopping_patience=2)],
)
# Train
logger.info("\nTraining model...")
trainer.train()
# Evaluate
logger.info("\nEvaluating model...")
eval_results = trainer.evaluate()
logger.info(f"Test Accuracy: {eval_results['eval_accuracy']:.4f}")
# Get predictions for detailed report
predictions = trainer.predict(test_dataset)
pred_labels = np.argmax(predictions.predictions, axis=1)
logger.info("\nClassification Report:")
logger.info(
classification_report(
test_df["label"], pred_labels, target_names=self.label_encoder.classes_
)
)
return {
"accuracy": eval_results["eval_accuracy"],
"num_classes": num_labels,
"training_samples": len(train_df),
"test_samples": len(test_df),
}
def predict(self, texts: List[str]) -> List[Dict]:
"""Predict severity for texts"""
if self.model is None or self.tokenizer is None:
raise ValueError("Model not trained. Call train() first.")
self.model.eval()
results = []
with torch.no_grad():
for text in texts:
inputs = self.tokenizer(
text,
padding=True,
truncation=True,
max_length=512,
return_tensors="pt",
)
inputs = {k: v.to(self.device) for k, v in inputs.items()}
outputs = self.model(**inputs)
probs = torch.softmax(outputs.logits, dim=1)
confidence, pred = torch.max(probs, dim=1)
severity = self.label_encoder.inverse_transform([pred.item()])[0]
results.append(
{
"severity": severity,
"confidence": confidence.item(),
"all_probabilities": {
label: prob.item()
for label, prob in zip(
self.label_encoder.classes_, probs[0]
)
},
}
)
return results
    def save(self, filepath: str):
        """Save model, tokenizer, and label encoder into the given directory"""
        logger.info(f"Saving severity classifier to {filepath}...")
        # Treat filepath as a directory; taking its dirname here would make
        # every classifier save into the same parent folder and overwrite
        # the others (load() below already accepts a directory path)
        model_dir = filepath
        os.makedirs(model_dir, exist_ok=True)
        # Save model and tokenizer
        self.model.save_pretrained(model_dir)
        self.tokenizer.save_pretrained(model_dir)
        # Save label encoder
        with open(os.path.join(model_dir, "label_encoder.pkl"), "wb") as f:
            pickle.dump(self.label_encoder, f)
        logger.info(f"✓ Model saved to {model_dir}")
def load(self, filepath: str):
"""Load model"""
model_dir = os.path.dirname(filepath) if os.path.isfile(filepath) else filepath
logger.info(f"Loading severity classifier from {model_dir}...")
self.tokenizer = AutoTokenizer.from_pretrained(model_dir)
self.model = AutoModelForSequenceClassification.from_pretrained(model_dir)
self.model.to(self.device)
with open(os.path.join(model_dir, "label_encoder.pkl"), "rb") as f:
self.label_encoder = pickle.load(f)
logger.info(f"✓ Model loaded")
class IssueTypeClassifier:
"""Classify issue type/category from report text"""
def __init__(self, model_name="indobenchmark/indobert-base-p1"):
self.model_name = model_name
self.tokenizer = None
self.model = None
self.label_encoder = LabelEncoder()
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
def prepare_data(self, reports: List[Dict]) -> pd.DataFrame:
"""Prepare data for training"""
data = []
for report in reports:
text = report.get("Report", "")
category = report.get("Irregularity_Complain_Category", "")
if not text or not category or len(text) < 10:
continue
data.append({"text": text, "category": category})
return pd.DataFrame(data)
def train(self, reports: List[Dict]):
"""Train issue type classifier"""
logger.info("=" * 60)
logger.info("Training Issue Type Classifier")
logger.info("=" * 60)
df = self.prepare_data(reports)
if len(df) < 20:
logger.error("Not enough data")
return None
        # Keep only categories with at least 5 samples
        category_counts = df["category"].value_counts()
        valid_categories = category_counts[category_counts >= 5].index
        df = df[df["category"].isin(valid_categories)]
        if df["category"].nunique() < 2:
            logger.error("Fewer than 2 categories with >= 5 samples; cannot train")
            return None
        logger.info(f"Prepared {len(df)} samples")
        logger.info(f"Categories: {df['category'].nunique()}")
# Encode labels
df["label"] = self.label_encoder.fit_transform(df["category"])
# Split
train_df, test_df = train_test_split(
df, test_size=0.2, random_state=42, stratify=df["category"]
)
# Compute class weights
classes = np.unique(df["label"])
class_weights = compute_class_weight(
class_weight="balanced", classes=classes, y=df["label"].values
)
class_weights = torch.tensor(class_weights, dtype=torch.float)
# Load model
self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)
self.model = AutoModelForSequenceClassification.from_pretrained(
self.model_name, num_labels=len(self.label_encoder.classes_)
)
self.model.to(self.device)
# Tokenize
def tokenize_function(examples):
return self.tokenizer(
examples["text"], padding="max_length", truncation=True, max_length=512
)
train_dataset = HFDataset.from_pandas(train_df).map(
tokenize_function, batched=True
)
test_dataset = HFDataset.from_pandas(test_df).map(
tokenize_function, batched=True
)
train_dataset.set_format(
"torch", columns=["input_ids", "attention_mask", "label"]
)
test_dataset.set_format(
"torch", columns=["input_ids", "attention_mask", "label"]
)
# Train
training_args = TrainingArguments(
output_dir="./results_category",
num_train_epochs=5,
per_device_train_batch_size=8,
per_device_eval_batch_size=16,
warmup_steps=50,
weight_decay=0.01,
logging_steps=10,
evaluation_strategy="epoch",
save_strategy="epoch",
load_best_model_at_end=True,
report_to="none",
)
def compute_metrics(eval_pred):
predictions, labels = eval_pred
predictions = np.argmax(predictions, axis=1)
return {"accuracy": accuracy_score(labels, predictions)}
class WeightedTrainer(Trainer):
def __init__(self, class_weights: torch.Tensor, *args, **kwargs):
super().__init__(*args, **kwargs)
self.class_weights = class_weights.to(self.model.device)
            # **kwargs absorbs extra arguments (e.g. num_items_in_batch)
            # passed by newer versions of transformers' Trainer
            def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
labels = inputs.get("labels")
outputs = model(**inputs)
logits = outputs.get("logits")
loss_fct = nn.CrossEntropyLoss(weight=self.class_weights)
loss = loss_fct(logits.view(-1, logits.size(-1)), labels.view(-1))
return (loss, outputs) if return_outputs else loss
trainer = WeightedTrainer(
class_weights=class_weights,
model=self.model,
args=training_args,
train_dataset=train_dataset,
eval_dataset=test_dataset,
compute_metrics=compute_metrics,
callbacks=[EarlyStoppingCallback(early_stopping_patience=2)],
)
trainer.train()
eval_results = trainer.evaluate()
logger.info(f"Test Accuracy: {eval_results['eval_accuracy']:.4f}")
return {"accuracy": eval_results["eval_accuracy"]}
def predict(self, texts: List[str]) -> List[Dict]:
"""Predict issue type"""
        if self.model is None or self.tokenizer is None:
            raise ValueError("Model not trained. Call train() or load() first.")
self.model.eval()
results = []
with torch.no_grad():
for text in texts:
inputs = self.tokenizer(
text,
padding=True,
truncation=True,
max_length=512,
return_tensors="pt",
)
inputs = {k: v.to(self.device) for k, v in inputs.items()}
outputs = self.model(**inputs)
probs = torch.softmax(outputs.logits, dim=1)
confidence, pred = torch.max(probs, dim=1)
category = self.label_encoder.inverse_transform([pred.item()])[0]
results.append({"issueType": category, "confidence": confidence.item()})
return results
    def save(self, filepath: str):
        """Save model, tokenizer, and label encoder into the given directory"""
        # Treat filepath as a directory (see SeverityClassifier.save)
        model_dir = filepath
        os.makedirs(model_dir, exist_ok=True)
        self.model.save_pretrained(model_dir)
        self.tokenizer.save_pretrained(model_dir)
        with open(os.path.join(model_dir, "label_encoder.pkl"), "wb") as f:
            pickle.dump(self.label_encoder, f)
def load(self, filepath: str):
"""Load model"""
model_dir = os.path.dirname(filepath) if os.path.isfile(filepath) else filepath
self.tokenizer = AutoTokenizer.from_pretrained(model_dir)
self.model = AutoModelForSequenceClassification.from_pretrained(model_dir)
self.model.to(self.device)
with open(os.path.join(model_dir, "label_encoder.pkl"), "rb") as f:
self.label_encoder = pickle.load(f)
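# Illustrative reload-for-inference sketch (the path must be the directory
# written by save()):
#
#     clf = IssueTypeClassifier()
#     clf.load("./models/nlp/issue_classifier")
#     clf.predict(["Kargo terlambat dimuat ke pesawat"])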
class SimpleSummarizer:
"""Simple extractive summarizer (no training needed)"""
def __init__(self):
self.important_keywords = [
"damage",
"torn",
"broken",
"cargo",
"baggage",
"passenger",
"delay",
"late",
"error",
"fail",
"problem",
"issue",
"incident",
]
def summarize(self, text: str, max_sentences: int = 3) -> Dict:
"""Generate extractive summary"""
if not text or len(text) < 50:
return {
"executiveSummary": text[:200] + "..." if len(text) > 200 else text,
"keyPoints": [],
}
# Split into sentences (simple approach)
sentences = text.replace("!", ".").replace("?", ".").split(".")
sentences = [s.strip() for s in sentences if len(s.strip()) > 20]
# Score sentences
sentence_scores = []
for sent in sentences:
score = sum(1 for kw in self.important_keywords if kw in sent.lower())
sentence_scores.append((sent, score))
        # Take the highest-scoring sentences
        sentence_scores.sort(key=lambda x: x[1], reverse=True)
        top_sentences = [s[0] for s in sentence_scores[:max_sentences]]
        if not top_sentences:
            # No sentence survived the length filter; fall back to a truncation
            return {"executiveSummary": text[:200], "keyPoints": []}
        # Executive summary
        summary = ". ".join(top_sentences) + "."
# Key points
key_points = []
if any(kw in text.lower() for kw in ["cargo", "uld"]):
key_points.append("Cargo-related issue")
if any(kw in text.lower() for kw in ["baggage", "bag"]):
key_points.append("Baggage handling issue")
if any(kw in text.lower() for kw in ["passenger", "pax"]):
key_points.append("Passenger service issue")
if any(kw in text.lower() for kw in ["damage", "torn", "broken"]):
key_points.append("Physical damage reported")
return {
"executiveSummary": summary[:300] + "..."
if len(summary) > 300
else summary,
"keyPoints": key_points[:5],
}
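# Illustrative usage (no training required; hypothetical report text):
#
#     summarizer = SimpleSummarizer()
#     summarizer.summarize(
#         "Baggage was found torn on arrival and the passenger filed a complaint. "
#         "Ground staff documented the damage and escalated the issue."
#     )
#     # -> {"executiveSummary": "...", "keyPoints": ["Baggage handling issue", ...]}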
def main():
"""Main training script"""
logger.info("=" * 60)
logger.info("Gapura AI - NLP Model Training")
logger.info("=" * 60)
# Initialize
sheets_service = GoogleSheetsService()
spreadsheet_id = os.getenv("GOOGLE_SHEET_ID")
if not spreadsheet_id:
logger.error("GOOGLE_SHEET_ID not set")
sys.exit(1)
# Fetch data
logger.info("\nFetching data...")
try:
non_cargo_data = sheets_service.fetch_sheet_data(
spreadsheet_id, "NON CARGO", "A1:AA500"
)
cargo_data = sheets_service.fetch_sheet_data(spreadsheet_id, "CGO", "A1:Z500")
all_data = non_cargo_data + cargo_data
logger.info(f"✓ Total records: {len(all_data)}")
except Exception as e:
logger.error(f"Error fetching data: {str(e)}")
sys.exit(1)
if len(all_data) < 20:
logger.error("Not enough data")
sys.exit(1)
# Train Severity Classifier
logger.info("\n" + "=" * 60)
severity_clf = SeverityClassifier()
severity_metrics = severity_clf.train(all_data)
if severity_metrics:
severity_clf.save("./models/nlp/severity_classifier")
logger.info(
f"✓ Severity classifier trained (accuracy: {severity_metrics['accuracy']:.4f})"
)
# Train Issue Type Classifier
logger.info("\n" + "=" * 60)
issue_clf = IssueTypeClassifier()
issue_metrics = issue_clf.train(all_data)
if issue_metrics:
issue_clf.save("./models/nlp/issue_classifier")
logger.info(
f"✓ Issue classifier trained (accuracy: {issue_metrics['accuracy']:.4f})"
)
    # Save Summarizer (no training needed); ensure the output directory exists
    # even when both classifier trainings were skipped
    summarizer = SimpleSummarizer()
    os.makedirs("./models/nlp", exist_ok=True)
    with open("./models/nlp/summarizer.pkl", "wb") as f:
        pickle.dump(summarizer, f)
logger.info("✓ Summarizer saved")
# Save combined metrics
metrics = {
"severity_classifier": severity_metrics,
"issue_classifier": issue_metrics,
"trained_at": datetime.now().isoformat(),
"total_samples": len(all_data),
}
with open("./models/nlp/training_metrics.json", "w") as f:
json.dump(metrics, f, indent=2, default=str)
logger.info("\n" + "=" * 60)
logger.info("NLP Training Complete!")
logger.info("=" * 60)
logger.info("Models saved to: ./models/nlp/")
if __name__ == "__main__":
main()
# ======= Multi-Task Fine-tuning and ONNX Export (Optional) =======
def train_multitask_and_export(
reports: List[Dict],
output_dir: str = "./models",
model_name: str = "distilbert-base-uncased",
epochs: int = 3,
batch_size: int = 8,
max_length: int = 256,
):
"""
Fine-tune DistilBERT with multiple heads (severity, issue_type, area) and export ONNX.
Saves:
- {output_dir}/multi_task_transformer.pt
- {output_dir}/multi_task_transformer.onnx
- {output_dir}/multi_task_label_encoders.pkl
"""
    # Most dependencies are already imported at module level; only the
    # multi-task-specific pieces are imported here
    from transformers import get_linear_schedule_with_warmup
    from torch.optim import AdamW
    from data.transformer_architecture import MultiTaskDistilBert
os.makedirs(output_dir, exist_ok=True)
# Prepare data with weak labels for severity if missing
texts, sev_labels, issue_labels, area_labels = [], [], [], []
for r in reports:
t = " ".join(
[
str(r.get("Report", "") or ""),
str(r.get("Root_Caused", "") or ""),
str(r.get("Action_Taken", "") or ""),
]
).strip()
if len(t) < 10:
continue
sev = r.get("Severity_Label")
if not sev:
tl = t.lower()
if any(
k in tl
for k in [
"emergency",
"darurat",
"critical",
"kritis",
"fire",
"kebakaran",
"injury",
"cedera",
]
):
sev = "Critical"
elif any(
k in tl for k in ["damage", "rusak", "broken", "pecah", "robek", "torn"]
):
sev = "High"
elif any(
k in tl
for k in ["delay", "terlambat", "telat", "wrong", "incorrect", "error"]
):
sev = "Medium"
else:
sev = "Low"
issue = r.get("Irregularity_Complain_Category", "")
area = (r.get("Area", "") or "").replace(" Area", "")
if not issue or not area:
continue
texts.append(t)
sev_labels.append(sev)
issue_labels.append(issue)
area_labels.append(area)
    if len(texts) < 40:
        logger.warning("Not enough multi-task samples (< 40); skipping")
        return None
# Encoders
sev_le, issue_le, area_le = LabelEncoder(), LabelEncoder(), LabelEncoder()
sev_ids = sev_le.fit_transform(sev_labels)
issue_ids = issue_le.fit_transform(issue_labels)
area_ids = area_le.fit_transform(area_labels)
label_encoders = {
"severity": sev_le,
"issue_type": issue_le,
"area": area_le,
}
# Dataset
class MTDataset(Dataset):
def __init__(self, tokenizer, texts, sev_ids, issue_ids, area_ids):
self.tok = tokenizer
self.texts = texts
self.sev = sev_ids
self.issue = issue_ids
self.area = area_ids
def __len__(self):
return len(self.texts)
def __getitem__(self, idx):
t = self.texts[idx]
enc = self.tok(
t,
padding="max_length",
truncation=True,
max_length=max_length,
return_tensors="pt",
)
item = {
"input_ids": enc["input_ids"].squeeze(0),
"attention_mask": enc["attention_mask"].squeeze(0),
"severity": int(self.sev[idx]),
"issue_type": int(self.issue[idx]),
"area": int(self.area[idx]),
}
return item
tokenizer = AutoTokenizer.from_pretrained(model_name)
ds = MTDataset(tokenizer, texts, sev_ids, issue_ids, area_ids)
dl = DataLoader(ds, batch_size=batch_size, shuffle=True)
# Model
num_labels_dict = {
"severity": len(sev_le.classes_),
"issue_type": len(issue_le.classes_),
"area": len(area_le.classes_),
}
model = MultiTaskDistilBert(num_labels_dict)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
# Optimizer & scheduler
optimizer = AdamW(model.parameters(), lr=3e-5, weight_decay=0.01)
total_steps = epochs * max(1, len(dl))
scheduler = get_linear_schedule_with_warmup(
optimizer,
num_warmup_steps=int(0.1 * total_steps),
num_training_steps=total_steps,
)
loss_fct = nn.CrossEntropyLoss()
model.train()
for epoch in range(epochs):
epoch_loss = 0.0
for batch in dl:
optimizer.zero_grad(set_to_none=True)
input_ids = batch["input_ids"].to(device)
attn = batch["attention_mask"].to(device)
logits_dict = model(input_ids=input_ids, attention_mask=attn)
loss = (
loss_fct(logits_dict["severity"], batch["severity"].to(device))
+ loss_fct(logits_dict["issue_type"], batch["issue_type"].to(device))
+ loss_fct(logits_dict["area"], batch["area"].to(device))
) / 3.0
loss.backward()
optimizer.step()
scheduler.step()
epoch_loss += loss.item()
logger.info(
f"[MultiTask] Epoch {epoch + 1}/{epochs} - loss={epoch_loss / len(dl):.4f}"
)
# Save checkpoint (PyTorch)
ckpt_path = os.path.join(output_dir, "multi_task_transformer.pt")
torch.save(
{
"model_state_dict": model.state_dict(),
"num_labels_dict": num_labels_dict,
"label_encoders": label_encoders,
},
ckpt_path,
)
logger.info(f"✓ Multi-task checkpoint saved to {ckpt_path}")
# Save encoders
with open(os.path.join(output_dir, "multi_task_label_encoders.pkl"), "wb") as f:
pickle.dump(label_encoders, f)
# Export ONNX
class Wrapper(nn.Module):
def __init__(self, mdl):
super().__init__()
self.mdl = mdl
def forward(self, input_ids, attention_mask):
out = self.mdl(input_ids=input_ids, attention_mask=attention_mask)
return out["severity"], out["issue_type"], out["area"]
wrapper = Wrapper(model).to(device)
wrapper.eval()
dummy = tokenizer(
"test input",
return_tensors="pt",
padding="max_length",
truncation=True,
max_length=max_length,
)
dummy = {k: v.to(device) for k, v in dummy.items()}
onnx_path = os.path.join(output_dir, "multi_task_transformer.onnx")
torch.onnx.export(
wrapper,
(dummy["input_ids"], dummy["attention_mask"]),
onnx_path,
input_names=["input_ids", "attention_mask"],
output_names=["severity", "issue_type", "area"],
dynamic_axes={
"input_ids": {0: "batch"},
"attention_mask": {0: "batch"},
"severity": {0: "batch"},
"issue_type": {0: "batch"},
"area": {0: "batch"},
},
opset_version=17,
)
logger.info(f"✓ ONNX multi-task model saved to {onnx_path}")