import re

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
from peft import PeftModel

SYSTEM_PROMPT = "You are a careful reasoning assistant for software and professional writing."


def extract_assistant(generated_text: str) -> str:
    # Keep only the final assistant turn: split on the Llama 3 assistant header
    # and cut the reply off at the end-of-turn token.
    m = re.split(r"<\|start_header_id\|>assistant<\|end_header_id\|>\s*", generated_text)
    if len(m) > 1:
        tail = m[-1].split("<|eot_id|>")[0]
        return tail.strip()
    return generated_text.strip()


def load(repo_id: str, base_model: str = "meta-llama/Llama-3.1-8B-Instruct"):
    # base_model must be the checkpoint the LoRA adapter was trained on;
    # the adapter repo name indicates a Llama 3.1 8B Instruct base.
    # The adapter repo also hosts the tokenizer used during fine-tuning.
    tok = AutoTokenizer.from_pretrained(repo_id, use_fast=True)
    if tok.pad_token is None:
        tok.pad_token = tok.eos_token
    tok.padding_side = "right"
    # Load the frozen base model in fp16, then attach the LoRA adapter on top.
    base = AutoModelForCausalLM.from_pretrained(
        base_model,
        torch_dtype=torch.float16,
        device_map="auto",
    )
    mdl = PeftModel.from_pretrained(base, repo_id)
    mdl.eval()
    # The model is already placed by device_map="auto" above, so the pipeline
    # reuses its devices; do not pass device_map again for an instantiated model.
    gen = pipeline("text-generation", model=mdl, tokenizer=tok)
    return tok, mdl, gen


def chat(gen, tok, user: str, max_new_tokens=256, do_sample=False, temperature=0.7, top_p=0.9):
    msgs = [
        {"role": "system", "content": SYSTEM_PROMPT},
        {"role": "user", "content": user},
    ]
    # Render the conversation with the model's chat template and append the
    # assistant header so generation starts at the assistant turn.
    prompt = tok.apply_chat_template(msgs, tokenize=False, add_generation_prompt=True)
    kwargs = dict(max_new_tokens=max_new_tokens, pad_token_id=tok.eos_token_id, return_full_text=True)
    if do_sample:
        kwargs.update(dict(do_sample=True, temperature=float(temperature), top_p=float(top_p)))
    else:
        kwargs.update(dict(do_sample=False))  # greedy decoding
    out = gen(prompt, **kwargs)[0]["generated_text"]
    return extract_assistant(out)


if __name__ == "__main__":
    REPO_ID = "Saunak359/llama-3.1-8b-reasoning-lora"
    tok, mdl, gen = load(REPO_ID)
    print(chat(gen, tok, "Write a short report on QLoRA with Introduction and Conclusion.", 220))
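    # A second call sketching the sampled-decoding path of chat(); the prompt text
    # and the temperature/top_p values here are illustrative, not tuned settings.
    print(chat(gen, tok, "Explain in two sentences when QLoRA is preferable to full fine-tuning.",
               max_new_tokens=128, do_sample=True, temperature=0.7, top_p=0.9))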