Update train_mlp_batches.py
Browse files- train_mlp_batches.py +3 -1
train_mlp_batches.py
CHANGED
|
@@ -81,7 +81,7 @@ def main():
|
|
| 81 |
parser.add_argument('--width', type=int, default=512, help='Number of neurons per hidden layer (default: 512)')
|
| 82 |
parser.add_argument('--batch_size', type=int, default=8, help='Batch size for training (default: 8)')
|
| 83 |
parser.add_argument('--save_model_dir', type=str, default='saved_models', help='Directory to save model checkpoints (default: saved_models)')
|
| 84 |
-
parser.add_argument('--access_token', type=str,
|
| 85 |
parser.add_argument('--upload_checkpoint', action='store_true', help='Upload checkpoint to ModelScope')
|
| 86 |
parser.add_argument('--delete_checkpoint', action='store_true', help='Delete local checkpoint after uploading')
|
| 87 |
args = parser.parse_args()
|
|
@@ -163,6 +163,8 @@ def main():
|
|
| 163 |
|
| 164 |
# Upload the model to ModelScope if specified
|
| 165 |
if args.upload_checkpoint:
|
|
|
|
|
|
|
| 166 |
api = HubApi()
|
| 167 |
api.login(args.access_token)
|
| 168 |
api.push_model(
|
|
|
|
| 81 |
parser.add_argument('--width', type=int, default=512, help='Number of neurons per hidden layer (default: 512)')
|
| 82 |
parser.add_argument('--batch_size', type=int, default=8, help='Batch size for training (default: 8)')
|
| 83 |
parser.add_argument('--save_model_dir', type=str, default='saved_models', help='Directory to save model checkpoints (default: saved_models)')
|
| 84 |
+
parser.add_argument('--access_token', type=str, help='ModelScope SDK access token (optional)')
|
| 85 |
parser.add_argument('--upload_checkpoint', action='store_true', help='Upload checkpoint to ModelScope')
|
| 86 |
parser.add_argument('--delete_checkpoint', action='store_true', help='Delete local checkpoint after uploading')
|
| 87 |
args = parser.parse_args()
|
|
|
|
| 163 |
|
| 164 |
# Upload the model to ModelScope if specified
|
| 165 |
if args.upload_checkpoint:
|
| 166 |
+
if not args.access_token:
|
| 167 |
+
raise ValueError("Access token is required for uploading to ModelScope.")
|
| 168 |
api = HubApi()
|
| 169 |
api.login(args.access_token)
|
| 170 |
api.push_model(
|