Quark-0.5M / inference.py
LH-Tech-AI's picture
Create inference.py
d93ac5d verified
# Minimal inference script: load a locally saved Llama checkpoint, sample a
# continuation for a fixed prompt, and print the decoded text.
#
# The first progress message is printed before the imports, presumably so the
# user sees output while the (slow) torch/transformers imports run.
print("[*] Loading libraries...")
import torch
from transformers import LlamaForCausalLM, PreTrainedTokenizerFast

# Local checkpoint directory; both the tokenizer and the model weights are
# loaded from it below.
model_path = "./llama-sub-1m-final"

print("[*] Loading tokenizer...")
tokenizer = PreTrainedTokenizerFast.from_pretrained(model_path)

print("[*] Loading model...")
model = LlamaForCausalLM.from_pretrained(model_path)
model.eval()  # inference mode: disables dropout

# NOTE(review): the trailing space is kept as authored. With many BPE
# tokenizers it becomes a standalone space token, which can degrade the
# continuation; confirm against how the training text was tokenized.
prompt = "Artificial intelligence is "
print(f"[*] Prompt: {prompt!r}")

# Keep the encoded tensors on the same device as the model weights.
inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

# A bare PreTrainedTokenizerFast frequently has no pad token configured.
# Fall back to EOS explicitly (the usual choice for decoder-only models)
# instead of handing generate() a None pad id.
pad_token_id = tokenizer.pad_token_id
if pad_token_id is None:
    pad_token_id = tokenizer.eos_token_id

with torch.no_grad():
    # input_ids / attention_mask are passed individually rather than via
    # **inputs: the base fast tokenizer may also return token_type_ids, which
    # LlamaForCausalLM does not accept as a generation kwarg. The mask is
    # fetched with .get() so a tokenizer config without it still works;
    # generate() then infers the mask itself.
    outputs = model.generate(
        input_ids=inputs["input_ids"],
        attention_mask=inputs.get("attention_mask"),
        max_new_tokens=150,
        do_sample=True,
        temperature=0.35,
        top_p=0.85,
        repetition_penalty=1.2,
        pad_token_id=pad_token_id,
        eos_token_id=tokenizer.eos_token_id,
    )

# For a decoder-only model, outputs[0] holds the prompt tokens followed by the
# generated ones, so the printed text includes the prompt itself.
print("[*] Output:", tokenizer.decode(outputs[0], skip_special_tokens=True))