Adapters
Safetensors
llama
tarantulas committed on
Commit
b5d2501
·
verified ·
1 Parent(s): 57edda7

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +52 -0
README.md CHANGED
@@ -22,6 +22,58 @@ This model combines the base `full_finetuned_llama3b` with LoRA fine-tuning on:
22
  from transformers import AutoTokenizer, AutoModelForCausalLM
23
  model = AutoModelForCausalLM.from_pretrained("your-hf-username/full_finetuned_llama3b")
24
  tokenizer = AutoTokenizer.from_pretrained("ai-factory/giant")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
25
  ```
26
 
27
  ## 👤 Authors
 
22
  from transformers import AutoTokenizer, AutoModelForCausalLM
23
  model = AutoModelForCausalLM.from_pretrained("your-hf-username/full_finetuned_llama3b")
24
  tokenizer = AutoTokenizer.from_pretrained("ai-factory/giant")
25
+ # Load base model
26
+ base_model = AutoModelForCausalLM.from_pretrained(
27
+ BASE_MODEL,
28
+ torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
29
+ device_map="auto",
30
+ trust_remote_code=True,
31
+ use_safetensors=True,
32
+ local_files_only=True
33
+ )
34
+
35
+ # Apply LoRA
36
+ peft_config = LoraConfig(
37
+ task_type=TaskType.CAUSAL_LM,
38
+ r=8,
39
+ lora_alpha=32,
40
+ lora_dropout=0.05,
41
+ bias="none",
42
+ target_modules=["q_proj", "k_proj", "v_proj", "o_proj"]
43
+ )
44
+ model = get_peft_model(base_model, peft_config)
45
+ model.eval()
46
+ if torch.cuda.is_available():
47
+ model = model.cuda()
48
+
49
+ # Load streaming datasets
50
+ arxiv = load_dataset("ai-factory/red_pajama_subset_arxiv_subset", split="train", streaming=True)
51
+ glaive = load_dataset("ai-factory/glaiveai-reasoning-v1-20m-chat", split="train", streaming=True)
52
+
53
+ def tokenize(example):
54
+ return tokenizer(example["text"], truncation=True, max_length=4096)
55
+
56
+ # Tokenize small samples
57
+ tokenized_arxiv = map(tokenize, islice(arxiv, args.sample_size))
58
+ tokenized_glaive = map(tokenize, islice(glaive, args.sample_size))
59
+
60
+ # Run forward + backward pass (init LoRA weights)
61
+ print("🔥 Training one step to initialize LoRA...")
62
+ for i, sample in enumerate(tokenized_arxiv):
63
+ if not sample.get("input_ids"):
64
+ continue
65
+ ids = torch.tensor(sample["input_ids"]).unsqueeze(0).to(model.device)
66
+ labels = ids.clone()
67
+ loss = model(input_ids=ids, labels=labels).loss
68
+ loss.backward()
69
+ break
70
+
71
+ # Merge LoRA and save
72
+ print("🔁 Merging adapter into base model...")
73
+ merged_model = model.merge_and_unload()
74
+ merged_model.save_pretrained(SAVE_DIR, safe_serialization=True)
75
+ tokenizer.save_pretrained(SAVE_DIR)
76
+ print(f"✅ Merged model saved to {SAVE_DIR}")
77
  ```
78
 
79
  ## 👤 Authors