File size: 1,881 Bytes
aeea5e9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
import re
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
from peft import PeftModel

# System-role message that chat() places ahead of the user's message in every conversation.
SYSTEM_PROMPT = "You are a careful reasoning assistant for software and professional writing."

def extract_assistant(generated_text: str) -> str:
    """Return the text of the last assistant turn found in *generated_text*.

    The input is cut at every Llama-3 style assistant header
    (``<|start_header_id|>assistant<|end_header_id|>`` plus any whitespace
    that follows it). Whatever comes after the final header, up to the first
    ``<|eot_id|>`` marker if one is present, is returned with surrounding
    whitespace removed.

    When no assistant header occurs, the whole input is returned stripped.
    """
    header_pattern = r"<\|start_header_id\|>assistant<\|end_header_id\|>\s*"
    segments = re.split(header_pattern, generated_text)

    # A single segment means the header never matched: nothing to cut away.
    if len(segments) <= 1:
        return generated_text.strip()

    # Keep only what precedes the end-of-turn marker in the final segment.
    reply, _, _ = segments[-1].partition("<|eot_id|>")
    return reply.strip()

def load(repo_id: str, base_model: str = "meta-llama/Llama-3.2-3B-Instruct"):
    """Build the tokenizer, the adapter-wrapped model and a generation pipeline.

    Args:
        repo_id: Identifier handed to ``AutoTokenizer.from_pretrained`` and to
            ``PeftModel.from_pretrained`` (tokenizer and adapter source).
        base_model: Identifier handed to
            ``AutoModelForCausalLM.from_pretrained`` (the underlying model).

    Returns:
        A ``(tokenizer, model, pipeline)`` tuple, in that order.
    """
    tokenizer = AutoTokenizer.from_pretrained(repo_id, use_fast=True)
    # Reuse the EOS token for padding when the tokenizer defines no pad token.
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    tokenizer.padding_side = "right"

    # Base weights are requested in float16 with automatic device placement.
    base_kwargs = dict(torch_dtype=torch.float16, device_map="auto")
    backbone = AutoModelForCausalLM.from_pretrained(base_model, **base_kwargs)

    # Wrap the base model with the adapter from repo_id and switch to eval mode.
    adapted = PeftModel.from_pretrained(backbone, repo_id)
    adapted.eval()

    generator = pipeline(
        "text-generation",
        model=adapted,
        tokenizer=tokenizer,
        device_map="auto",
    )
    return tokenizer, adapted, generator

def chat(gen, tok, user: str, max_new_tokens=256, do_sample=False, temperature=0.7, top_p=0.9):
    """Run one system+user exchange through *gen* and return the reply text.

    Args:
        gen: Callable invoked as ``gen(prompt, **kwargs)``; its first result
            must be a mapping with a ``"generated_text"`` entry.
        tok: Tokenizer providing ``apply_chat_template`` and ``eos_token_id``.
        user: Content of the user message.
        max_new_tokens: Forwarded to *gen* unchanged.
        do_sample: When true, sampling is enabled and ``temperature`` and
            ``top_p`` are forwarded as floats; when false they are not sent.
        temperature: Sampling temperature, used only if ``do_sample`` is true.
        top_p: Nucleus sampling threshold, used only if ``do_sample`` is true.

    Returns:
        The generated reply, passed through ``extract_assistant``.
    """
    msgs = [
        {"role": "system", "content": SYSTEM_PROMPT},
        {"role": "user", "content": user},
    ]
    prompt = tok.apply_chat_template(msgs, tokenize=False, add_generation_prompt=True)

    # Request only the continuation. Relying on the full text plus a
    # Llama-3-specific header regex to remove the prompt breaks for any chat
    # template without those markers: the fallback would then hand back the
    # entire prompt (system and user text included) as the "reply".
    kwargs = dict(
        max_new_tokens=max_new_tokens,
        pad_token_id=tok.eos_token_id,
        return_full_text=False,
    )
    if do_sample:
        kwargs.update(do_sample=True, temperature=float(temperature), top_p=float(top_p))
    else:
        kwargs.update(do_sample=False)

    out = gen(prompt, **kwargs)[0]["generated_text"]
    # Still routed through extract_assistant: it strips whitespace and trims
    # at any header / end-of-turn marker that may remain in the text.
    return extract_assistant(out)

if __name__ == "__main__":
    # Demo run: load the adapter, ask a single question, print the reply.
    REPO_ID = "Saunak359/llama-3.1-8b-reasoning-lora"
    # NOTE(review): the repo name suggests a Llama-3.1 8B adapter, while load()
    # defaults to a Llama-3.2 3B base model — confirm the two are compatible.
    tok, mdl, gen = load(REPO_ID)
    question = "Write a short report on QLoRA with Introduction and Conclusion."
    answer = chat(gen, tok, question, max_new_tokens=220)
    print(answer)