"""Inference script: load a LoRA adapter onto its base model and chat with it."""

import re

import torch
from peft import PeftModel
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline

SYSTEM_PROMPT = "You are a careful reasoning assistant for software and professional writing."


def extract_assistant(generated_text: str) -> str:
    """Pull the assistant's reply out of a full Llama-3-style chat transcript."""
    parts = re.split(r"<\|start_header_id\|>assistant<\|end_header_id\|>\s*", generated_text)
    if len(parts) > 1:
        # Keep only the text after the last assistant header, up to the turn terminator.
        tail = parts[-1].split("<|eot_id|>")[0]
        return tail.strip()
    return generated_text.strip()


def load(repo_id: str, base_model: str = "meta-llama/Llama-3.1-8B-Instruct"):
    """Load tokenizer and base model, attach the LoRA adapter, and build a pipeline.

    The base model must be the one the adapter was trained on; the default here
    matches the Llama 3.1 8B adapter loaded in __main__ (an 8B adapter cannot be
    applied to a differently sized base such as Llama 3.2 3B).
    """
    tok = AutoTokenizer.from_pretrained(repo_id, use_fast=True)
    if tok.pad_token is None:
        # Llama tokenizers ship without a pad token; reuse EOS for padding.
        tok.pad_token = tok.eos_token
    tok.padding_side = "right"
    base = AutoModelForCausalLM.from_pretrained(
        base_model,
        torch_dtype=torch.float16,
        device_map="auto",
    )
    mdl = PeftModel.from_pretrained(base, repo_id)
    mdl.eval()
    # The model is already placed across devices by device_map above, so the
    # pipeline is not given its own device argument.
    gen = pipeline("text-generation", model=mdl, tokenizer=tok)
    return tok, mdl, gen


def chat(gen, tok, user: str, max_new_tokens=256, do_sample=False, temperature=0.7, top_p=0.9):
    """Render a chat prompt, generate, and return only the assistant turn."""
    msgs = [
        {"role": "system", "content": SYSTEM_PROMPT},
        {"role": "user", "content": user},
    ]
    prompt = tok.apply_chat_template(msgs, tokenize=False, add_generation_prompt=True)
    kwargs = dict(
        max_new_tokens=max_new_tokens,
        pad_token_id=tok.eos_token_id,
        return_full_text=True,  # keep the prompt so extract_assistant can split on it
    )
    if do_sample:
        kwargs.update(do_sample=True, temperature=float(temperature), top_p=float(top_p))
    else:
        kwargs.update(do_sample=False)  # greedy decoding
    out = gen(prompt, **kwargs)[0]["generated_text"]
    return extract_assistant(out)


if __name__ == "__main__":
    REPO_ID = "Saunak359/llama-3.1-8b-reasoning-lora"
    tok, mdl, gen = load(REPO_ID)
    print(chat(gen, tok, "Write a short report on QLoRA with Introduction and Conclusion.", 220))
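
    # A minimal sketch of sampled decoding, exercising the otherwise-unused
    # do_sample / temperature / top_p parameters of chat(). The question text
    # below is illustrative, not part of the original script.
    print(chat(gen, tok, "Explain when QLoRA is preferable to full fine-tuning.",
               max_new_tokens=200, do_sample=True, temperature=0.7, top_p=0.9))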