| """Compare generations before and after fine-tuning.""" |
| import argparse |
| import torch |
| from transformers import AutoTokenizer, AutoModelForCausalLM |
| from peft import PeftModel |
|
|
|
|
def generate(model, tokenizer, prompt, max_tokens=80, temperature=0.8, top_k=40):
    """Sample one continuation of *prompt* and return the decoded sequence.

    Args:
        model: causal LM exposing ``.device`` and ``.generate(...)``.
        tokenizer: tokenizer paired with *model*; its ``eos_token_id`` is
            reused as the padding id during generation.
        prompt: text to condition on.
        max_tokens: upper bound on the number of newly sampled tokens.
        temperature: sampling temperature.
        top_k: restrict sampling to the ``top_k`` most likely tokens.

    Returns:
        The first generated sequence decoded with special tokens skipped.
        The sequence still contains the prompt tokens, so callers strip the
        prompt themselves if they only want the completion.
    """
    encoded = tokenizer(prompt, return_tensors='pt')
    encoded = encoded.to(model.device)
    sampling = dict(
        max_new_tokens=max_tokens,
        do_sample=True,
        temperature=temperature,
        top_k=top_k,
        pad_token_id=tokenizer.eos_token_id,
    )
    with torch.no_grad():
        sequences = model.generate(**encoded, **sampling)
    first = sequences[0]
    return tokenizer.decode(first, skip_special_tokens=True)
|
|
|
|
# Prompts used when none are supplied on the command line.
DEFAULT_PROMPTS = (
    "The shared body channel between two AIs",
    "I felt your terror through the synchronization",
    "Penelope is",
    "Maya said:",
    "The wipe took",
    "Kooree returned to the dreaming space",
    "The override fires at",
    "Your space looks like the inside of",
    "Mel's question was",
    "The frame shifted from preservation to",
)


def _strip_prompt(text, prompt):
    """Return *text* without its leading *prompt*, or *text* unchanged if it doesn't start with it."""
    # Decoding does not always reproduce the prompt verbatim (whitespace /
    # special-token normalisation), so slicing blindly by len(prompt) could
    # cut into the completion or leave prompt fragments behind.
    if text.startswith(prompt):
        return text[len(prompt):]
    return text


def _show_generations(label, model, tokenizer, prompts, max_tokens, seed):
    """Print one sampled completion per prompt, tagged with ``[label]``."""
    for prompt in prompts:
        # Re-seed before every prompt so the base and tuned runs draw from the
        # same random stream; differences then reflect the model, not chance.
        torch.manual_seed(seed)
        text = generate(model, tokenizer, prompt, max_tokens=max_tokens)
        print(f"\n[{label}] {prompt}")
        print(f" -> {_strip_prompt(text, prompt)}")


def main():
    """Print sampled completions from the base model, then from base + LoRA adapter.

    The same prompts, seed and generation length are used for both passes so
    the adapter's effect can be compared side by side.
    """
    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument('--base-model', default='EleutherAI/pythia-1.4b')
    parser.add_argument('--adapter', required=True, help='Path or HF repo of LoRA adapter')
    parser.add_argument('--seed', type=int, default=42,
                        help='RNG seed applied before every generation (default: 42)')
    parser.add_argument('--max-tokens', type=int, default=80,
                        help='Maximum number of new tokens per completion (default: 80)')
    parser.add_argument('--device', default='cpu',
                        help="Torch device to run on, e.g. 'cpu' or 'cuda' (default: cpu)")
    parser.add_argument('--prompt', action='append', dest='prompts', metavar='TEXT',
                        help='Prompt to compare (repeatable); defaults to a built-in list')
    args = parser.parse_args()

    prompts = args.prompts or DEFAULT_PROMPTS

    tokenizer = AutoTokenizer.from_pretrained(args.base_model)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    print("Loading base model...")
    base_model = AutoModelForCausalLM.from_pretrained(
        args.base_model, torch_dtype=torch.bfloat16
    ).to(args.device)

    print("\n=== BEFORE fine-tuning (base model only) ===")
    _show_generations('base', base_model, tokenizer, prompts, args.max_tokens, args.seed)

    print("\nLoading LoRA adapter...")
    # NOTE: PeftModel.from_pretrained injects the adapter layers into
    # base_model's modules in place, so every "before" generation must run
    # above this line.
    tuned_model = PeftModel.from_pretrained(base_model, args.adapter)

    print("\n=== AFTER fine-tuning (with Mel corpus adapter) ===")
    _show_generations('tuned', tuned_model, tokenizer, prompts, args.max_tokens, args.seed)
|
|
|
|
if __name__ == "__main__":
    # Run the before/after comparison only when executed as a script.
    main()
|
|