# Reason-RAG / inference.py
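"""Inference helpers for a LoRA reasoning adapter on a Llama-3 Instruct base.

Loads the PEFT adapter from the Hub on top of its base checkpoint and wraps
generation in the tokenizer's chat template. Model names below follow the
repo's defaults; swap them for your own adapter/base pair.
"""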
import re
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
from peft import PeftModel
SYSTEM_PROMPT = "You are a careful reasoning assistant for software and professional writing."
def extract_assistant(generated_text: str) -> str:
    """Pull the assistant reply out of a Llama-3 chat transcript."""
    # Take whatever follows the last assistant header, then cut at the first
    # end-of-turn marker; fall back to the whole text if no header is present.
    parts = re.split(r"<\|start_header_id\|>assistant<\|end_header_id\|>\s*", generated_text)
    tail = parts[-1] if len(parts) > 1 else generated_text
    return tail.split("<|eot_id|>")[0].strip()
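# Illustrative transcript fragment (hypothetical text, real Llama-3 markers):
#   "...<|start_header_id|>assistant<|end_header_id|>\nHello.<|eot_id|>" -> "Hello."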
def load(repo_id: str, base_model: str = "meta-llama/Llama-3.1-8B-Instruct"):
    """Load tokenizer, LoRA-adapted model, and a text-generation pipeline.

    base_model must be the checkpoint the adapter was trained on; the default
    matches the Llama-3.1-8B adapter used in __main__ below.
    """
    tok = AutoTokenizer.from_pretrained(repo_id, use_fast=True)
    if tok.pad_token is None:
        # Llama-3 tokenizers ship without a pad token; reuse EOS for padding.
        tok.pad_token = tok.eos_token
    tok.padding_side = "right"
    base = AutoModelForCausalLM.from_pretrained(
        base_model,
        torch_dtype=torch.float16,
        device_map="auto",
    )
    mdl = PeftModel.from_pretrained(base, repo_id)
    mdl.eval()
    # The base model is already dispatched across devices by accelerate, so
    # don't pass device_map to the pipeline again.
    gen = pipeline("text-generation", model=mdl, tokenizer=tok)
    return tok, mdl, gen
def chat(gen, tok, user: str, max_new_tokens=256, do_sample=False, temperature=0.7, top_p=0.9):
    """Run a single user turn through the chat template and return the reply."""
    msgs = [
        {"role": "system", "content": SYSTEM_PROMPT},
        {"role": "user", "content": user},
    ]
    prompt = tok.apply_chat_template(msgs, tokenize=False, add_generation_prompt=True)
    # Ask for only the newly generated text: the pipeline decodes with special
    # tokens stripped, so the template headers extract_assistant() looks for
    # would not survive a full-text return.
    kwargs = dict(max_new_tokens=max_new_tokens, pad_token_id=tok.eos_token_id, return_full_text=False)
    if do_sample:
        kwargs.update(dict(do_sample=True, temperature=float(temperature), top_p=float(top_p)))
    else:
        # Greedy decoding; temperature/top_p are ignored on this path.
        kwargs.update(dict(do_sample=False))
    out = gen(prompt, **kwargs)[0]["generated_text"]
    return extract_assistant(out)
if __name__ == "__main__":
    REPO_ID = "Saunak359/llama-3.1-8b-reasoning-lora"
    tok, mdl, gen = load(REPO_ID)
    # Deterministic (greedy) generation, capped at 220 new tokens.
    print(chat(gen, tok, "Write a short report on QLoRA with Introduction and Conclusion.", max_new_tokens=220))
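    # A sampled variant through the same chat() helper; the prompt here is
    # illustrative, not from the original script.
    print(chat(gen, tok, "Summarize the trade-offs of LoRA versus full fine-tuning.",
               do_sample=True, temperature=0.7, top_p=0.9))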