"""Supervised fine-tuning script for PROTAC SMILES splitting.

Trains a causal LM with TRL's SFTTrainer and evaluates generated SMILES
by chemical validity (RDKit parseability).
"""
import os
from typing import Dict, Any

import torch
from rdkit import Chem
from transformers import TrainerCallback
from trl import SFTConfig, SFTTrainer

from protac_splitter.llms.data_utils import load_tokenized_dataset
from protac_splitter.llms.model_utils import get_model
os.environ["CUDA_VISIBLE_DEVICES"] = "0"  # Pin all CUDA work to the first GPU; torch falls back to CPU when none is visible.
def score_function(smiles1, predicted_smiles):
    """Score a generated SMILES string by chemical validity.

    Returns 1 when ``predicted_smiles`` parses to an RDKit molecule,
    0 otherwise. ``smiles1`` (the prompt SMILES) is currently unused
    and kept only for interface stability.
    """
    return int(bool(Chem.MolFromSmiles(predicted_smiles)))
class CustomSFTTrainer(SFTTrainer):
    """SFTTrainer whose ``evaluate`` reports the fraction of valid generated SMILES.

    Each eval example's ``"text"`` field is expected to look like
    ``"Smiles1 Smiles2.Smiles3.Smiles4"`` where the first whitespace-separated
    token is the prompt and the remainder is the target completion.
    """

    def evaluate(self, eval_dataset=None, ignore_keys=None, metric_key_prefix: str = "eval"):
        """Generate predictions for ``eval_dataset`` and log the mean validity score.

        Returns a metrics dict ``{f"{metric_key_prefix}_average_score": float}``.
        """
        if eval_dataset is None:
            eval_dataset = self.eval_dataset

        prediction_output = self.predict(eval_dataset)
        preds = prediction_output.predictions
        # Bug fix: for causal-LM heads, Trainer.predict returns float logits
        # (possibly wrapped in a tuple) of shape (batch, seq, vocab);
        # batch_decode on a 3-D float array fails. Unwrap and reduce to token
        # ids first. NOTE(review): argmax over teacher-forced logits is not
        # true autoregressive generation — confirm this matches the intended
        # evaluation protocol.
        if isinstance(preds, tuple):
            preds = preds[0]
        if getattr(preds, "ndim", 2) == 3:
            preds = preds.argmax(axis=-1)
        generated_texts = self.tokenizer.batch_decode(preds, skip_special_tokens=True)

        total_score = 0
        total_samples = len(generated_texts)
        for i, example in enumerate(eval_dataset):
            input_text = example["text"]  # Full input: "Smiles1 Smiles2.Smiles3.Smiles4"
            smiles1 = input_text.split(" ")[0]  # Smiles1 is the prompt
            # Strip the prompt from the decoded text to isolate the completion.
            predicted_completion = generated_texts[i].removeprefix(smiles1).strip()
            total_score += score_function(smiles1, predicted_completion)

        # Guard against an empty eval set to avoid ZeroDivisionError.
        average_score = total_score / total_samples if total_samples > 0 else 0

        metrics = {f"{metric_key_prefix}_average_score": average_score}
        self.log(metrics)
        return metrics
def train():
    """Fine-tune the model with SFT and periodic SMILES-validity evaluation.

    Loads the model and tokenized dataset via project helpers, builds an
    ``SFTConfig``, and runs ``CustomSFTTrainer.train()``. Checkpoints go to
    ``./trained_model``.
    """
    model = get_model()  # Project helper; presumably returns a HF causal LM — TODO confirm.
    # NOTE(review): the original read ``model.tokenizer`` into an unused local;
    # standard HF models have no ``.tokenizer`` attribute, so that line would
    # raise AttributeError. Removed as dead code.

    dataset = load_tokenized_dataset()

    # Bug fix: SFTTrainer requires a TrainingArguments/SFTConfig instance for
    # ``args``; passing a plain dict raises at trainer construction.
    training_args = SFTConfig(
        output_dir="./trained_model",
        evaluation_strategy="steps",  # evaluate every ``logging_steps`` steps
        save_strategy="steps",
        logging_steps=100,
        save_steps=500,
        num_train_epochs=3,
        per_device_train_batch_size=8,
        per_device_eval_batch_size=8,
        learning_rate=5e-5,
        save_total_limit=2,  # keep only the 2 most recent checkpoints
    )

    trainer = CustomSFTTrainer(
        model=model,
        args=training_args,
        train_dataset=dataset["train"],
        eval_dataset=dataset["validation"],
    )
    trainer.train()
# Script entry point: run training only when executed directly, not on import.
if __name__ == "__main__":
    train()