import os

# Pin the process to a single GPU before torch initializes CUDA.
os.environ["CUDA_VISIBLE_DEVICES"] = "2"

import inspect

import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")

from datasets import load_dataset
from huggingface_hub import notebook_login  # call notebook_login() first if the gated Gemma weights require authentication
from peft import LoraConfig
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
from trl import SFTConfig, SFTTrainer

# LoRA configuration: rank-16 adapters on all attention and MLP projections.
lora_config = LoraConfig(
    r=16,
    lora_alpha=32,
    lora_dropout=0.05,
    bias="none",
    task_type="CAUSAL_LM",
    target_modules=[
        "q_proj",
        "k_proj",
        "v_proj",
        "o_proj",
        "gate_proj",
        "up_proj",
        "down_proj",
    ],
)

model_id = "google/gemma-2-2b-it"

# 4-bit NF4 quantization with double quantization; compute happens in bfloat16.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
)

tokenizer = AutoTokenizer.from_pretrained(model_id)
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "right"

model = AutoModelForCausalLM.from_pretrained(
    model_id,
    quantization_config=bnb_config,
    device_map="auto",
)
# Disable the KV cache during training; it is incompatible with gradient checkpointing.
model.config.use_cache = False

dataset = load_dataset("tatsu-lab/alpaca", split="train")


def format_alpaca_prompt(example):
    """Render one Alpaca record into the standard instruction/response prompt."""
    instruction = example["instruction"].strip()
    user_input = example["input"].strip()
    response = example["output"].strip()
    if user_input:
        prompt = (
            f"### Instruction:\n{instruction}\n\n"
            f"### Input:\n{user_input}\n\n"
            "### Response:\n"
        )
    else:
        prompt = f"### Instruction:\n{instruction}\n\n### Response:\n"
    return {"text": f"{prompt}{response}"}


train_dataset = dataset.map(format_alpaca_prompt)
# Keep a small slice so the script runs quickly end to end.
train_dataset = train_dataset.select(range(100))
print(train_dataset)
print(train_dataset[0]["text"][:300])

# Quick sanity check before fine-tuning:
# take a few prompts from the train set and run base-model inference.
num_preview_samples = 3
preview_dataset = train_dataset.select(range(num_preview_samples))
print(f"\nPre-finetuning preview on {num_preview_samples} samples:")

comparison_rows = []
for idx, sample in enumerate(preview_dataset):
    full_text = sample["text"]
    split_token = "### Response:\n"
    prompt_text = full_text.split(split_token)[0] + split_token
    expected_response = full_text.split(split_token, 1)[1]

    inputs = tokenizer(prompt_text, return_tensors="pt").to(model.device)
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=120,
            do_sample=True,
            temperature=0.7,
            top_p=0.9,
            eos_token_id=tokenizer.eos_token_id,
        )
    decoded = tokenizer.decode(outputs[0], skip_special_tokens=True)

    print(f"\n--- Sample {idx + 1} Prompt ---\n{prompt_text}")
    print(f"--- Sample {idx + 1} Base Model Output ---\n{decoded}")
    comparison_rows.append(
        {
            "id": idx + 1,
            "prompt": prompt_text,
            "target": expected_response,
            "before": decoded,
        }
    )

# Build the training config defensively: TRL has moved arguments between
# releases (e.g. packing, dataset_text_field), so filter the kwargs down to
# whatever the installed SFTConfig actually accepts.
config_kwargs = {
    "output_dir": "./gemma-2-2b-it-alpaca-lora",
    "num_train_epochs": 1,
    "per_device_train_batch_size": 1,
    "gradient_accumulation_steps": 8,
    "learning_rate": 2e-4,
    "lr_scheduler_type": "cosine",
    "warmup_ratio": 0.03,
    "logging_steps": 10,
    "save_strategy": "epoch",
    "eval_strategy": "no",
    "optim": "paged_adamw_8bit",
    "bf16": torch.cuda.is_available(),
    "gradient_checkpointing": True,
    "packing": True,
    "report_to": "none",
}
supported_config_keys = set(inspect.signature(SFTConfig.__init__).parameters.keys())
config_kwargs = {k: v for k, v in config_kwargs.items() if k in supported_config_keys}
training_args = SFTConfig(**config_kwargs)

# Same version-proofing for the trainer: dataset_text_field and max_seq_length
# are SFTTrainer arguments in older TRL and SFTConfig fields in newer TRL.
trainer_kwargs = {
    "model": model,
    "args": training_args,
    "train_dataset": train_dataset,
    "peft_config": lora_config,
    "dataset_text_field": "text",
    "max_seq_length": 1024,
}
supported_trainer_keys = set(inspect.signature(SFTTrainer.__init__).parameters.keys())
trainer_kwargs = {k: v for k, v in trainer_kwargs.items() if k in supported_trainer_keys}

trainer = SFTTrainer(**trainer_kwargs)

train_result = trainer.train()

adapter_out = "./gemma-2-2b-it-alpaca-lora/final_adapter"
trainer.model.save_pretrained(adapter_out)
tokenizer.save_pretrained(adapter_out)
print(f"Saved LoRA adapter to: {adapter_out}")

# Training is done: re-enable the KV cache so generation runs at normal speed.
trainer.model.config.use_cache = True

print("\nPost-finetuning comparison on same samples:")
for row in comparison_rows:
    inputs = tokenizer(row["prompt"], return_tensors="pt").to(trainer.model.device)
    with torch.no_grad():
        outputs = trainer.model.generate(
            **inputs,
            max_new_tokens=120,
            do_sample=True,
            temperature=0.7,
            top_p=0.9,
            eos_token_id=tokenizer.eos_token_id,
        )
    after_decoded = tokenizer.decode(outputs[0], skip_special_tokens=True)

    print(f"\n=== Sample {row['id']} ===")
    print(f"Prompt:\n{row['prompt']}")
    print(f"\nGround Truth Response:\n{row['target']}")
    print(f"\nBefore Fine-tuning:\n{row['before']}")
    print(f"\nAfter Fine-tuning:\n{after_decoded}")

# Free-form prompt to spot-check the fine-tuned model.
prompt = "### Instruction:\nExplain photosynthesis in simple words.\n\n### Response:\n"
inputs = tokenizer(prompt, return_tensors="pt").to(trainer.model.device)
with torch.no_grad():
    outputs = trainer.model.generate(
        **inputs,
        max_new_tokens=120,
        do_sample=True,
        temperature=0.7,
        top_p=0.9,
        eos_token_id=tokenizer.eos_token_id,
    )
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
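
# Optional follow-up: reloading the saved adapter for standalone inference.
# A minimal sketch, assuming the adapter_out path produced above; the merged
# output directory name is hypothetical. PeftModel.from_pretrained attaches
# the LoRA weights to a freshly loaded base model, and merge_and_unload()
# folds them into the base so the result behaves like a plain transformers
# model. The base is loaded in bfloat16 rather than 4-bit here, since merging
# LoRA deltas into quantized weights is not generally supported.
from peft import PeftModel

# Release the training-time model before loading a second copy.
del trainer, model
torch.cuda.empty_cache()

base_model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.bfloat16,
    device_map="auto",
)
ft_model = PeftModel.from_pretrained(base_model, adapter_out)
ft_model = ft_model.merge_and_unload()  # bake the adapter into the base weights

merged_out = "./gemma-2-2b-it-alpaca-merged"  # hypothetical output dir
ft_model.save_pretrained(merged_out)
tokenizer.save_pretrained(merged_out)
print(f"Saved merged model to: {merged_out}")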