Update train_model.py

Pass use_auth_token=True to the load_dataset calls, add logging around the pad_token setup, broaden the error message in the tokenizer/model try block, and drop stale comments and trailing blank lines.
train_model.py  CHANGED  (+9, -12)
@@ -70,9 +70,9 @@ def load_and_prepare_dataset(task, dataset_name, tokenizer, sequence_length):
     # Check if dataset_name includes a configuration
     if '/' in dataset_name:
         dataset, config = dataset_name.split('/', 1)
-        dataset = load_dataset("Salesforce/wikitext", "wikitext-103-raw-v1", split='train')
+        dataset = load_dataset("Salesforce/wikitext", "wikitext-103-raw-v1", split='train', use_auth_token=True)
     else:
-        dataset = load_dataset("Salesforce/wikitext", "wikitext-103-raw-v1", split='train')
+        dataset = load_dataset("Salesforce/wikitext", "wikitext-103-raw-v1", split='train', use_auth_token=True)
     logging.info("Dataset loaded successfully for generation task.")
     def tokenize_function(examples):
         return tokenizer(examples['text'], truncation=True, max_length=sequence_length)
@@ -185,6 +185,8 @@ def main():
         if tokenizer.pad_token is None:
             logging.info("Setting pad_token to eos_token.")
             tokenizer.pad_token = tokenizer.eos_token
+            logging.info(f"Tokenizer pad_token set to: {tokenizer.pad_token}")
+            # Resize model's token embeddings after setting pad_token
             model = initialize_model(
                 task=args.task,
                 model_name=args.model_name,
@@ -195,7 +197,10 @@ def main():
                 attention_heads=args.attention_heads
             )
             model.resize_token_embeddings(len(tokenizer))
+            logging.info("Resized token embeddings to accommodate pad_token.")
         else:
+            logging.info(f"Tokenizer already has pad_token set to: {tokenizer.pad_token}")
+            # Initialize model normally
             model = initialize_model(
                 task=args.task,
                 model_name=args.model_name,
@@ -206,7 +211,7 @@ def main():
                 attention_heads=args.attention_heads
             )
         except Exception as e:
-            logging.error(f"Error initializing tokenizer: {str(e)}")
+            logging.error(f"Error initializing tokenizer or model: {str(e)}")
             raise e
 
         # Load and prepare dataset
@@ -221,9 +226,6 @@ def main():
         logging.error("Failed to load and prepare dataset.")
         raise e
 
-        # Initialize model (Already initialized above)
-        # model = initialize_model(...) # Moved above to handle pad_token
-
     # Define data collator
     if args.task == "generation":
         data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)
@@ -245,7 +247,7 @@ def main():
             learning_rate=5e-4,
             remove_unused_columns=False,
             push_to_hub=False  # We'll handle pushing manually
-
+
         )
     elif args.task == "classification":
         training_args = TrainingArguments(
@@ -313,8 +315,3 @@ def main():
 
 if __name__ == "__main__":
     main()
-
-
-
-
-
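For context, the pad_token handling this commit annotates follows a standard transformers pattern. A minimal sketch of that pattern, assuming a GPT-2-style checkpoint (the model name is illustrative; the script builds its model via initialize_model instead):

    from transformers import AutoModelForCausalLM, AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    model = AutoModelForCausalLM.from_pretrained("gpt2")

    if tokenizer.pad_token is None:
        # Aliasing eos_token adds no new vocabulary entry, so the embedding
        # matrix keeps its size and resize_token_embeddings is a no-op here.
        tokenizer.pad_token = tokenizer.eos_token
        # A genuinely new token would make the resize necessary:
        # tokenizer.add_special_tokens({"pad_token": "[PAD]"})
        model.resize_token_embeddings(len(tokenizer))

The resize call is harmless but redundant when pad_token merely aliases eos_token; it only changes the model once the tokenizer has actually grown.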
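A note on the use_auth_token=True argument this commit introduces: recent datasets releases deprecate use_auth_token in favor of token, so the same call under a newer version would look like the sketch below (the exact version cutoff depends on the environment):

    from datasets import load_dataset

    dataset = load_dataset(
        "Salesforce/wikitext",
        "wikitext-103-raw-v1",
        split="train",
        token=True,  # reuse the token stored by `huggingface-cli login`
    )

For a public dataset like wikitext the token is not strictly needed; it matters only for gated or private repositories.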
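The generation branch keeps DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False), which is why the pad_token fix matters: with mlm=False the collator pads each batch and copies input_ids into labels, setting padded positions to -100 so the loss ignores them. A minimal sketch with an illustrative tokenizer (not the one the script configures):

    from transformers import AutoTokenizer, DataCollatorForLanguageModeling

    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    tokenizer.pad_token = tokenizer.eos_token  # batch padding needs a pad_token

    collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)
    batch = collator([tokenizer("hello world"), tokenizer("hi")])
    # batch["labels"] mirrors batch["input_ids"], with padding positions
    # replaced by -100 so cross-entropy skips them.

One caveat of aliasing pad_token to eos_token: genuine end-of-sequence tokens in the data are also masked to -100 in the labels, which can keep the model from learning to emit EOS.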