Spaces:
Sleeping
Sleeping
Update train_model.py
Browse files
- train_model.py +4 -4
train_model.py
CHANGED
|
@@ -70,18 +70,18 @@ def load_and_prepare_dataset(task, dataset_name, tokenizer, sequence_length):
|
|
| 70 |
# Check if dataset_name includes a configuration
|
| 71 |
if '/' in dataset_name:
|
| 72 |
dataset, config = dataset_name.split('/', 1)
|
| 73 |
-
dataset = load_dataset("Salesforce/wikitext", "wikitext-103-raw-v1", split='train'
|
| 74 |
else:
|
| 75 |
-
dataset = load_dataset("Salesforce/wikitext", "wikitext-103-raw-v1", split='train'
|
| 76 |
logging.info("Dataset loaded successfully for generation task.")
|
| 77 |
def tokenize_function(examples):
|
| 78 |
return tokenizer(examples['text'], truncation=True, max_length=sequence_length)
|
| 79 |
elif task == "classification":
|
| 80 |
if '/' in dataset_name:
|
| 81 |
dataset, config = dataset_name.split('/', 1)
|
| 82 |
-
dataset = load_dataset(dataset, config, split='train'
|
| 83 |
else:
|
| 84 |
-
dataset = load_dataset(dataset_name, split='train'
|
| 85 |
logging.info("Dataset loaded successfully for classification task.")
|
| 86 |
# Assuming the dataset has 'text' and 'label' columns
|
| 87 |
def tokenize_function(examples):
|
|
|
|
| 70 |
# Check if dataset_name includes a configuration
|
| 71 |
if '/' in dataset_name:
|
| 72 |
dataset, config = dataset_name.split('/', 1)
|
| 73 |
+
dataset = load_dataset("Salesforce/wikitext", "wikitext-103-raw-v1", split='train')
|
| 74 |
else:
|
| 75 |
+
dataset = load_dataset("Salesforce/wikitext", "wikitext-103-raw-v1", split='train')
|
| 76 |
logging.info("Dataset loaded successfully for generation task.")
|
| 77 |
def tokenize_function(examples):
|
| 78 |
return tokenizer(examples['text'], truncation=True, max_length=sequence_length)
|
| 79 |
elif task == "classification":
|
| 80 |
if '/' in dataset_name:
|
| 81 |
dataset, config = dataset_name.split('/', 1)
|
| 82 |
+
dataset = load_dataset(dataset, config, split='train')
|
| 83 |
else:
|
| 84 |
+
dataset = load_dataset(dataset_name, split='train')
|
| 85 |
logging.info("Dataset loaded successfully for classification task.")
|
| 86 |
# Assuming the dataset has 'text' and 'label' columns
|
| 87 |
def tokenize_function(examples):
|