# readctrl/habib/th.py
import os
# Pin the run to a single GPU; this must be set before torch initializes CUDA.
os.environ["CUDA_VISIBLE_DEVICES"] = "2"
import inspect
import torch
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")
from datasets import load_dataset
from peft import LoraConfig
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
from trl import SFTConfig, SFTTrainer
# LoRA configuration: rank-16 adapters on all attention and MLP projections.
lora_config = LoraConfig(
r=16,
lora_alpha=32,
lora_dropout=0.05,
bias="none",
task_type="CAUSAL_LM",
target_modules=[
"q_proj", "k_proj", "v_proj", "o_proj",
"gate_proj", "up_proj", "down_proj",
],
)
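# Optional sanity check (sketch, not executed): once the base model is loaded below,
# the target_modules above can be confirmed by listing its Linear submodule names, e.g.
#   sorted({name.split(".")[-1] for name, mod in model.named_modules()
#           if isinstance(mod, torch.nn.Linear)})
# which for Gemma-2 should include the q/k/v/o and gate/up/down projections.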
model_id = "google/gemma-2-2b-it"
# 4-bit NF4 quantization with double quantization; compute is done in bfloat16.
bnb_config = BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_compute_dtype=torch.bfloat16,
bnb_4bit_use_double_quant=True,
bnb_4bit_quant_type="nf4",
)
tokenizer = AutoTokenizer.from_pretrained(model_id)
# Pad with EOS and right-pad sequences, the usual setup for causal-LM fine-tuning.
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "right"
model = AutoModelForCausalLM.from_pretrained(
model_id,
quantization_config=bnb_config,
device_map="auto",
)
# Disable the KV cache; it is not compatible with gradient checkpointing during training.
model.config.use_cache = False
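# Rough footprint check: with NF4 double-quantized weights, the 2B model should report
# only a few GB here (get_memory_footprint is a standard transformers model helper).
print(f"Base model memory footprint: {model.get_memory_footprint() / 1e9:.2f} GB")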
# Alpaca instruction-tuning dataset (instruction / input / output triples).
dataset = load_dataset("tatsu-lab/alpaca", split="train")
def format_alpaca_prompt(example):
    """Render one Alpaca record as a single prompt + response training string."""
instruction = example["instruction"].strip()
user_input = example["input"].strip()
response = example["output"].strip()
if user_input:
prompt = (
f"### Instruction:\n{instruction}\n\n"
f"### Input:\n{user_input}\n\n"
"### Response:\n"
)
else:
prompt = f"### Instruction:\n{instruction}\n\n### Response:\n"
return {"text": f"{prompt}{response}"}
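# Minimal self-check of the formatter on a made-up record (the values below are
# illustrative, not taken from the dataset): the rendered text should start with the
# instruction header and end with the response.
_demo = format_alpaca_prompt(
    {"instruction": "Add the numbers.", "input": "2 and 3", "output": "5"}
)
assert _demo["text"].startswith("### Instruction:") and _demo["text"].endswith("5")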
train_dataset = dataset.map(format_alpaca_prompt)
# Keep a small 100-example subset so the end-to-end run finishes quickly.
train_dataset = train_dataset.select(range(100))
print(train_dataset)
print(train_dataset[0]["text"][:300])
# Quick sanity check before fine-tuning:
# take a few prompts from the train set and run base-model inference.
num_preview_samples = 3
preview_dataset = train_dataset.select(range(num_preview_samples))
print(f"\nPre-finetuning preview on {num_preview_samples} samples:")
comparison_rows = []
for idx, sample in enumerate(preview_dataset):
full_text = sample["text"]
split_token = "### Response:\n"
    prompt_body, expected_response = full_text.split(split_token, 1)
    prompt_text = prompt_body + split_token
inputs = tokenizer(prompt_text, return_tensors="pt").to(model.device)
with torch.no_grad():
outputs = model.generate(
**inputs,
max_new_tokens=120,
do_sample=True,
temperature=0.7,
top_p=0.9,
eos_token_id=tokenizer.eos_token_id,
)
decoded = tokenizer.decode(outputs[0], skip_special_tokens=True)
print(f"\n--- Sample {idx + 1} Prompt ---\n{prompt_text}")
print(f"--- Sample {idx + 1} Base Model Output ---\n{decoded}")
comparison_rows.append(
{
"id": idx + 1,
"prompt": prompt_text,
"target": expected_response,
"before": decoded,
}
)
# Training options as a plain dict; keys unsupported by the installed TRL version are filtered out below.
config_kwargs = {
"output_dir": "./gemma-2-2b-it-alpaca-lora",
"num_train_epochs": 1,
"per_device_train_batch_size": 1,
"gradient_accumulation_steps": 8,
"learning_rate": 2e-4,
"lr_scheduler_type": "cosine",
"warmup_ratio": 0.03,
"logging_steps": 10,
"save_strategy": "epoch",
"eval_strategy": "no",
"optim": "paged_adamw_8bit",
"bf16": torch.cuda.is_available(),
"gradient_checkpointing": True,
"packing": True,
"report_to": "none",
}
# Keep only the options that this TRL version's SFTConfig constructor accepts.
supported_config_keys = set(inspect.signature(SFTConfig.__init__).parameters.keys())
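# Make silently-dropped options visible: these keys were requested above but are not
# accepted by the installed TRL version's SFTConfig.
unsupported_config_keys = [k for k in config_kwargs if k not in supported_config_keys]
if unsupported_config_keys:
    print(f"Dropping unsupported SFTConfig options: {unsupported_config_keys}")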
config_kwargs = {k: v for k, v in config_kwargs.items() if k in supported_config_keys}
training_args = SFTConfig(**config_kwargs)
# dataset_text_field and max_seq_length live on SFTConfig in newer TRL releases; the filter below drops them here if this version rejects them.
trainer_kwargs = {
"model": model,
"args": training_args,
"train_dataset": train_dataset,
"peft_config": lora_config,
"dataset_text_field": "text",
"max_seq_length": 1024,
}
supported_trainer_keys = set(inspect.signature(SFTTrainer.__init__).parameters.keys())
trainer_kwargs = {k: v for k, v in trainer_kwargs.items() if k in supported_trainer_keys}
trainer = SFTTrainer(**trainer_kwargs)
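# Optional check (assumes the trainer wrapped the base model with the LoRA adapter,
# which is the default behaviour when peft_config is passed): report how many
# parameters will actually be trained.
if hasattr(trainer.model, "print_trainable_parameters"):
    trainer.model.print_trainable_parameters()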
# Run LoRA fine-tuning; only the adapter weights are updated and saved below.
train_result = trainer.train()
adapter_out = "./gemma-2-2b-it-alpaca-lora/final_adapter"
trainer.model.save_pretrained(adapter_out)
tokenizer.save_pretrained(adapter_out)
print(f"Saved LoRA adapter to: {adapter_out}")
print("\nPost-finetuning comparison on same samples:")
for row in comparison_rows:
inputs = tokenizer(row["prompt"], return_tensors="pt").to(trainer.model.device)
with torch.no_grad():
outputs = trainer.model.generate(
**inputs,
max_new_tokens=120,
do_sample=True,
temperature=0.7,
top_p=0.9,
eos_token_id=tokenizer.eos_token_id,
)
after_decoded = tokenizer.decode(outputs[0], skip_special_tokens=True)
print(f"\n=== Sample {row['id']} ===")
print(f"Prompt:\n{row['prompt']}")
print(f"\nGround Truth Response:\n{row['target']}")
print(f"\nBefore Fine-tuning:\n{row['before']}")
print(f"\nAfter Fine-tuning:\n{after_decoded}")
# Final quick generation check with a fresh prompt in the same Alpaca format.
prompt = "### Instruction:\nExplain photosynthesis in simple words.\n\n### Response:\n"
inputs = tokenizer(prompt, return_tensors="pt").to(trainer.model.device)
with torch.no_grad():
outputs = trainer.model.generate(
**inputs,
max_new_tokens=120,
do_sample=True,
temperature=0.7,
top_p=0.9,
eos_token_id=tokenizer.eos_token_id,
)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
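# Deployment sketch (not executed here): for a single standalone checkpoint, reload the
# base weights in bf16 rather than 4-bit, attach the adapter, and merge it in.
# from peft import PeftModel
# base_bf16 = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16)
# merged = PeftModel.from_pretrained(base_bf16, adapter_out).merge_and_unload()
# merged.save_pretrained("./gemma-2-2b-it-alpaca-merged")  # hypothetical output path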