"""Fine-tune a sequence-classification model on the MarbleX dataset.

Hyperparameters are read from config.json; the trained model can optionally
be pushed to the Hugging Face Hub.
"""

import json
import logging
import os

# Point the transformers cache at /tmp. Set this before transformers is
# imported so it actually takes effect.
os.environ['TRANSFORMERS_CACHE'] = '/tmp/huggingface_cache'

import numpy as np
from datasets import load_dataset
from huggingface_hub import HfFolder
from sklearn.metrics import accuracy_score, precision_recall_fscore_support
from transformers import (
    AutoModelForSequenceClassification,
    AutoTokenizer,
    Trainer,
    TrainingArguments,
)

logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')

def compute_metrics(pred):
    """Compute accuracy and weighted precision/recall/F1 for an EvalPrediction."""
    labels = pred.label_ids
    preds = pred.predictions.argmax(-1)
    precision, recall, f1, _ = precision_recall_fscore_support(labels, preds, average='weighted')
    acc = accuracy_score(labels, preds)
    return {
        'accuracy': acc,
        'f1': f1,
        'precision': precision,
        'recall': recall
    }

def setup_training():
    logging.info("Starting the training setup process")

    with open('config.json', 'r') as f:
        config = json.load(f)

    logging.info(f"Loaded configuration: {config}")

    logging.info("Loading the MarbleX dataset")
    dataset = load_dataset("Oranblock/marblex_dataset", "config1")

    logging.info(f"Dataset columns: {dataset['train'].column_names}")

    # Derive the label count from the distinct values in the target column.
    unique_labels = dataset['train'].unique(config['target_column'])
    num_labels = len(unique_labels)
    logging.info(f"Number of unique labels: {num_labels}")

    logging.info(f"Dataset loaded. Train size: {len(dataset['train'])}, Test size: {len(dataset['test'])}")

    logging.info(f"Loading tokenizer and model: {config['model_name']}")
    tokenizer = AutoTokenizer.from_pretrained(config['model_name'])
    model = AutoModelForSequenceClassification.from_pretrained(
        config['model_name'],
        num_labels=num_labels
    )

    logging.info("Tokenizing the dataset")

    def tokenize_function(examples):
        # Concatenate the configured text columns into one input string per
        # example. (This assumes the columns are meant to be joined; feeding
        # the stacked rows straight to the tokenizer only works for exactly
        # two columns treated as sentence pairs.)
        features = np.stack([examples[col] for col in config['text_columns']], axis=1)
        texts = [" ".join(map(str, row)) for row in features.tolist()]
        return tokenizer(texts, padding="max_length", truncation=True)

    tokenized_datasets = dataset.map(tokenize_function, batched=True)

    # Trainer expects the label column to be named "labels"; rename the target
    # column if needed (this assumes it already holds integer class ids).
    if config['target_column'] != 'labels':
        tokenized_datasets = tokenized_datasets.rename_column(config['target_column'], 'labels')

    logging.info("Dataset tokenization completed")

    logging.info("Setting up training arguments")
    training_args = TrainingArguments(
        output_dir="./results",
        num_train_epochs=config['num_train_epochs'],
        per_device_train_batch_size=config['per_device_train_batch_size'],
        per_device_eval_batch_size=config['per_device_eval_batch_size'],
        warmup_ratio=config['warmup_ratio'],
        weight_decay=config['weight_decay'],
        learning_rate=config['learning_rate'],
        fp16=config['fp16'],
        evaluation_strategy="epoch",
        save_strategy="epoch",
        load_best_model_at_end=True,
        push_to_hub=config['push_to_hub'],
        hub_model_id=config['hub_model_id'],
        logging_dir='./logs',
        logging_steps=100,
    )

    logging.info("Initializing Trainer")
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=tokenized_datasets["train"],
        eval_dataset=tokenized_datasets["test"],
        tokenizer=tokenizer,
        compute_metrics=compute_metrics
    )

    logging.info("Starting the training process")
    trainer.train()

    logging.info("Evaluating the model")
    eval_results = trainer.evaluate()
    logging.info(f"Evaluation results: {eval_results}")

    if config['push_to_hub']:
        logging.info("Pushing model to Hugging Face Hub")
        trainer.push_to_hub()
        logging.info(f"Model pushed to {config['hub_model_id']}")

    logging.info("Training process completed")

if __name__ == "__main__":
    # Store the Hub token up front so that pushing the model can authenticate.
    hf_token = os.environ.get('HF_TOKEN')
    if hf_token:
        HfFolder.save_token(hf_token)
        logging.info("Hugging Face token set")
    else:
        logging.warning("HF_TOKEN not found in environment variables")

    setup_training()
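
# Example invocation (the script filename here is illustrative): place
# config.json in the working directory, then run
#
#   HF_TOKEN=<your token> python train_marblex.py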