import os

from datasets import Dataset
from huggingface_hub import login
from transformers import AutoModelForCausalLM, AutoTokenizer, Trainer, TrainingArguments, DataCollatorForSeq2Seq
# --- Hub authentication ------------------------------------------------------
# Log in to the Hugging Face Hub so the fine-tuned model can be pushed later.
# The access token is read from the HF_TOKEN environment variable instead of
# being written into the script: a token committed to source is exposed to
# everyone who can read the file. When HF_TOKEN is unset, token=None lets
# login() fall back to its own interactive prompt.
login(token=os.environ.get("HF_TOKEN"))
# --- Base checkpoint ---------------------------------------------------------
# Hub id of the checkpoint that is fine-tuned below.
base_model = "PerceptronAI/Isaac-0.1"

# trust_remote_code=True is passed to both loaders, allowing the custom code
# shipped with the checkpoint repository to be used.
tokenizer = AutoTokenizer.from_pretrained(
    base_model,
    trust_remote_code=True,
)
model = AutoModelForCausalLM.from_pretrained(
    base_model,
    torch_dtype="auto",
    trust_remote_code=True,
)
# --- Training examples -------------------------------------------------------
# (prompt, reply) pairs used for fine-tuning.
_EXAMPLES = (
    ("Hello, who are you?", "I am Cass2.0, your AI assistant."),
    ("Tell me a joke.", "Why did the robot cross the road? To recharge itself!"),
    ("What's your purpose?", "I help you with answers, coding, and ideas as Cass2.0."),
)

# One record per pair, keyed by "input" (prompt) and "output" (reply).
data = [{"input": prompt, "output": reply} for prompt, reply in _EXAMPLES]

# Wrap the records in a datasets.Dataset so they can be mapped and trained on.
dataset = Dataset.from_list(data)
# --- Tokenization ------------------------------------------------------------
def tokenize(batch, max_length=128):
    """Turn a batch of prompt/answer pairs into causal-LM training features.

    A causal language model is trained to predict token ``i + 1`` from tokens
    ``0..i`` of the same sequence, so ``labels`` must line up position by
    position with ``input_ids``. Each example is therefore encoded as a single
    sequence (prompt, a newline, the answer, then the EOS token if the
    tokenizer defines one). The labels are a copy of ``input_ids`` in which
    prompt tokens and padding are replaced by ``-100``, the index ignored by
    the loss, so only the answer tokens are learned.

    Args:
        batch: Mapping with parallel lists of strings under ``"input"``
            (prompts) and ``"output"`` (answers), as passed by
            ``Dataset.map(..., batched=True)``.
        max_length: Length every sequence is truncated/padded to.

    Returns:
        The tokenizer's encoding of the combined sequences with an added
        ``"labels"`` entry (one list of ints per example).
    """
    eos = tokenizer.eos_token or ""
    prompts = [f"{question}\n" for question in batch["input"]]
    texts = [
        f"{prompt}{answer}{eos}"
        for prompt, answer in zip(prompts, batch["output"])
    ]

    encoded = tokenizer(
        texts,
        truncation=True,
        padding="max_length",
        max_length=max_length,
        return_attention_mask=True,
    )

    # Token count of each prompt on its own (no padding), used to know how many
    # leading real tokens belong to the prompt.
    # NOTE(review): assumes the tokenizer only adds special tokens at the start
    # of a sequence; one that appends EOS would mask one answer token too many.
    prompt_ids = tokenizer(prompts, truncation=True, max_length=max_length)["input_ids"]
    prompt_lengths = [len(ids) for ids in prompt_ids]

    labels = []
    for ids, mask, prompt_len in zip(
        encoded["input_ids"], encoded["attention_mask"], prompt_lengths
    ):
        real_seen = 0  # non-padding tokens seen so far; works for either padding side
        row = []
        for token_id, is_real in zip(ids, mask):
            if not is_real:
                row.append(-100)  # padding never contributes to the loss
                continue
            row.append(token_id if real_seen >= prompt_len else -100)
            real_seen += 1
        labels.append(row)

    encoded["labels"] = labels
    return encoded
# --- Trainer setup -----------------------------------------------------------
# Apply tokenize() to the whole dataset; batched=True passes it lists of rows.
tokenized_dataset = dataset.map(function=tokenize, batched=True)

# Collator that assembles the tokenized rows into batches for the Trainer.
data_collator = DataCollatorForSeq2Seq(tokenizer, model=model)

training_args = TrainingArguments(
    # Optimisation
    learning_rate=5e-5,
    num_train_epochs=3,
    per_device_train_batch_size=2,
    fp16=True,
    # Checkpoints and logging
    output_dir="./cass2.0",
    save_steps=50,
    save_total_limit=2,
    logging_steps=10,
    # Hub upload
    push_to_hub=True,
    hub_model_id="cass2.0",
)

trainer = Trainer(
    data_collator=data_collator,
    tokenizer=tokenizer,
    train_dataset=tokenized_dataset,
    args=training_args,
    model=model,
)
# --- Train, save, publish ----------------------------------------------------
# Run the fine-tuning loop.
print("🚀 Training Cass2.0...")
trainer.train()

# Write the model first, then the tokenizer, into the same local folder.
for artifact in (model, tokenizer):
    artifact.save_pretrained("./cass2.0")
print("✅ Model saved locally in './cass2.0'")

# Upload the result to the Hugging Face Hub.
trainer.push_to_hub()
print("🌐 Model pushed to Hugging Face Hub as 'cass2.0'")