Oranblock committed on
Commit
53096ab
·
verified ·
1 Parent(s): b3bf396

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +114 -0
app.py ADDED
@@ -0,0 +1,114 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import os
import json
import logging

# Set cache directory to a writable location.
# This has to run BEFORE the Hugging Face libraries are imported: transformers
# reads TRANSFORMERS_CACHE once, at import time, so assigning it after the
# import (as was done previously) does not change where models are cached.
os.environ['TRANSFORMERS_CACHE'] = '/tmp/huggingface_cache'

from datasets import load_dataset  # noqa: E402  (import after env setup is deliberate)
from transformers import AutoTokenizer, AutoModelForSequenceClassification, TrainingArguments, Trainer  # noqa: E402
from sklearn.metrics import accuracy_score, precision_recall_fscore_support  # noqa: E402
from huggingface_hub import HfFolder  # noqa: E402

# Set up logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
14
+
15
def compute_metrics(pred):
    """Compute accuracy and weighted precision/recall/F1 for an evaluation pass.

    Args:
        pred: Object exposing ``label_ids`` (reference labels) and
            ``predictions`` (model outputs whose last axis is argmax-ed to
            obtain the predicted class). ``predictions`` may also be a tuple
            of outputs, in which case the first element is used.

    Returns:
        dict: ``accuracy``, ``f1``, ``precision`` and ``recall``; the last
        three are computed with ``average='weighted'``.
    """
    labels = pred.label_ids
    scores = pred.predictions
    # Models that emit extra outputs hand back a tuple; a tuple has no
    # ``argmax``. Presumably the class scores come first -- verify if a model
    # with a non-standard output order is ever used.
    if isinstance(scores, tuple):
        scores = scores[0]
    preds = scores.argmax(-1)
    # zero_division=0 yields the same value as the default for classes that
    # are never predicted, but states it explicitly instead of emitting an
    # UndefinedMetricWarning on every evaluation.
    precision, recall, f1, _ = precision_recall_fscore_support(
        labels, preds, average='weighted', zero_division=0
    )
    acc = accuracy_score(labels, preds)
    return {
        'accuracy': acc,
        'f1': f1,
        'precision': precision,
        'recall': recall,
    }
26
+
27
def setup_training():
    """Fine-tune and evaluate a sequence-classification model from ``config.json``.

    Steps: read ``config.json`` from the current working directory, load the
    ``Oranblock/marblex_dataset`` dataset, tokenize its text column, train
    with ``Trainer``, evaluate on the ``validation`` split and, when
    ``push_to_hub`` is enabled, push the model to the Hugging Face Hub.

    Required config keys: ``model_name``, ``text_column``, ``target_column``,
    ``num_train_epochs``, ``per_device_train_batch_size``,
    ``per_device_eval_batch_size``, ``warmup_ratio``, ``weight_decay``,
    ``learning_rate``, ``fp16``.
    Optional config keys: ``push_to_hub`` (default ``False``) and
    ``hub_model_id`` (default ``None``).

    Raises:
        FileNotFoundError: If ``config.json`` does not exist.
        KeyError: If a required configuration key is missing.
    """
    logging.info("Starting the training setup process")

    # Load configuration (JSON is UTF-8 by specification; do not depend on the
    # platform's default encoding).
    with open('config.json', 'r', encoding='utf-8') as f:
        config = json.load(f)

    logging.info(f"Loaded configuration: {config}")

    # Hub settings are only meaningful when pushing, so they are optional.
    push_to_hub = config.get('push_to_hub', False)
    hub_model_id = config.get('hub_model_id')
    label_column = config['target_column']

    # Load your dataset
    logging.info("Loading the MarbleX dataset")
    dataset = load_dataset("Oranblock/marblex_dataset")
    logging.info(f"Dataset loaded. Train size: {len(dataset['train'])}, Validation size: {len(dataset['validation'])}")

    # Load tokenizer and model
    logging.info(f"Loading tokenizer and model: {config['model_name']}")
    tokenizer = AutoTokenizer.from_pretrained(config['model_name'])
    model = AutoModelForSequenceClassification.from_pretrained(
        config['model_name'],
        # One output per class name declared on the target column.
        num_labels=len(dataset['train'].features[label_column].names)
    )

    # Tokenize the dataset
    logging.info("Tokenizing the dataset")

    def tokenize_function(examples):
        return tokenizer(examples[config['text_column']], padding="max_length", truncation=True)

    tokenized_datasets = dataset.map(tokenize_function, batched=True)

    # Trainer drops columns the model's forward() does not accept, keeping only
    # the conventional label names. A target column with any other name would
    # be discarded and the model would return no loss, so rename it.
    if label_column not in ('label', 'labels', 'label_ids'):
        tokenized_datasets = tokenized_datasets.rename_column(label_column, 'labels')
    logging.info("Dataset tokenization completed")

    # Set up training arguments
    logging.info("Setting up training arguments")
    training_args = TrainingArguments(
        output_dir="./results",
        num_train_epochs=config['num_train_epochs'],
        per_device_train_batch_size=config['per_device_train_batch_size'],
        per_device_eval_batch_size=config['per_device_eval_batch_size'],
        warmup_ratio=config['warmup_ratio'],
        weight_decay=config['weight_decay'],
        learning_rate=config['learning_rate'],
        fp16=config['fp16'],
        # Evaluation and checkpointing must share a schedule for
        # load_best_model_at_end to work.
        evaluation_strategy="epoch",
        save_strategy="epoch",
        load_best_model_at_end=True,
        push_to_hub=push_to_hub,
        hub_model_id=hub_model_id,
        logging_dir='./logs',
        logging_steps=100,
    )

    # Initialize Trainer
    logging.info("Initializing Trainer")
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=tokenized_datasets["train"],
        eval_dataset=tokenized_datasets["validation"],
        tokenizer=tokenizer,
        compute_metrics=compute_metrics
    )

    # Start training
    logging.info("Starting the training process")
    trainer.train()

    # Evaluate the model
    logging.info("Evaluating the model")
    eval_results = trainer.evaluate()
    logging.info(f"Evaluation results: {eval_results}")

    # Push model to hub if configured
    if push_to_hub:
        logging.info("Pushing model to Hugging Face Hub")
        trainer.push_to_hub()
        logging.info(f"Model pushed to {hub_model_id}")

    logging.info("Training process completed")
104
+
105
if __name__ == "__main__":
    # Store the Hub token taken from the environment, when one is provided,
    # before kicking off training; otherwise just warn and carry on.
    token_from_env = os.getenv('HF_TOKEN')
    if not token_from_env:
        logging.warning("HF_TOKEN not found in environment variables")
    else:
        HfFolder.save_token(token_from_env)
        logging.info("Hugging Face token set")

    setup_training()