Spaces:
Sleeping
Sleeping
| from setfit import SetFitModel, SetFitTrainer | |
| from sentence_transformers.losses import CosineSimilarityLoss | |
def create_classifier(model_path):
    """Load a previously trained SetFit model to use as a text classifier.

    Args:
        model_path: Location of the saved model, handed unchanged to
            ``SetFitModel.from_pretrained``.

    Returns:
        The object returned by ``SetFitModel.from_pretrained``.
    """
    # local_files_only=True is forwarded so loading relies on files already
    # on disk.
    return SetFitModel.from_pretrained(model_path, local_files_only=True)
def run_setfit_training(
    session_id, model_id, model_name, train_dataset, batch_size, num_iterations
) -> str:
    """Fine-tune a SetFit model on ``train_dataset`` and save it to disk.

    The base model is loaded with ``SetFitModel.from_pretrained(model_id)``,
    trained for a single epoch with ``CosineSimilarityLoss``, and written to
    ``./models/{session_id}/{model_id}_{model_name}``.

    Args:
        session_id: Name of the sub-directory under ``./models`` that
            receives the trained model.
        model_id: Identifier passed to ``SetFitModel.from_pretrained``; it is
            also the first part of the output directory name, so an id that
            contains ``/`` produces nested directories.
        model_name: Suffix appended to ``model_id`` in the output directory
            name.
        train_dataset: Training data with ``text`` and ``label`` columns
            (see ``column_mapping``); must support ``len()``.
        batch_size: Batch size forwarded to ``SetFitTrainer``.
        num_iterations: Number of text pairs to generate for contrastive
            learning.

    Returns:
        The directory path the trained model was saved to.
    """
    model = SetFitModel.from_pretrained(model_id)

    trainer = SetFitTrainer(
        model=model,
        train_dataset=train_dataset,
        # The trainer is given the training set as its eval set as well; no
        # evaluation is run here, and any metric computed from it would be
        # measured on training data.
        eval_dataset=train_dataset,
        loss_class=CosineSimilarityLoss,
        metric="accuracy",
        batch_size=batch_size,
        num_iterations=num_iterations,
        num_epochs=1,  # epochs of contrastive learning
        column_mapping={"text": "text", "label": "label"},
    )
    trainer.train()

    print(f"model used: {model_id}")
    print(f"train dataset: {len(train_dataset)} samples")

    # Build the path once so the saved location and the returned value can
    # never disagree.
    save_model_path = f"./models/{session_id}/{model_id}_{model_name}"
    # Use the public save API rather than the private ``_save_pretrained``
    # hook; it creates the target directory before delegating to the hook.
    trainer.model.save_pretrained(save_model_path)
    return save_model_path