toyllama-30m / test.py
sapbot's picture
Update test.py
491217a verified
import torch
from transformers import PreTrainedTokenizerFast, LlamaForCausalLM
# Identifier handed to `from_pretrained` for both the tokenizer and the model.
MODEL_DIR: str = "sapbot/toyllama-30m"

# --- Generation Settings ---
# Cap on newly generated tokens per prompt; raised slightly so the model can
# finish its thoughts.
MAX_NEW_TOKENS: int = 150
# Sampling temperature forwarded to `model.generate`.
TEMPERATURE: float = 0.7
# `top_p` value forwarded to `model.generate`.
TOP_P: float = 0.9
def _build_byte_decoder():
    """Return the inverse GPT-2 byte-level alphabet as a ``{char: byte}`` dict.

    Byte-level BPE represents every raw byte as one printable unicode
    character. Bytes that are already printable map to themselves; the
    remaining ones (control characters, space, ...) are assigned code points
    starting at 256, which is why a space shows up as "Ġ" and a newline as "Ċ".
    """
    printable = (
        list(range(ord("!"), ord("~") + 1))  # visible ASCII
        + list(range(0xA1, 0xAC + 1))  # "¡" .. "¬"
        + list(range(0xAE, 0xFF + 1))  # "®" .. "ÿ"
    )
    decoder = {chr(value): value for value in printable}
    already_mapped = set(printable)
    shift = 0
    for value in range(256):
        if value not in already_mapped:
            decoder[chr(256 + shift)] = value
            shift += 1
    return decoder


# Built once at import time; the table is a fixed 256-entry mapping.
_BYTE_DECODER = _build_byte_decoder()


def _decode_byte_level(tokens):
    """Convert byte-level BPE token strings back into readable text.

    Args:
        tokens: Iterable of token strings (special tokens already removed).

    Returns:
        The decoded text. Characters that are not part of the byte-level
        alphabet are kept as-is, and byte sequences that do not form valid
        UTF-8 (e.g. a multi-byte character cut off by the token limit) are
        rendered as U+FFFD instead of raising.
    """
    raw = bytearray()
    for char in "".join(tokens):
        value = _BYTE_DECODER.get(char)
        if value is None:
            # Not a byte-level symbol: pass the character through unchanged.
            raw.extend(char.encode("utf-8"))
        else:
            raw.append(value)
    return raw.decode("utf-8", errors="replace")


def main():
    """Load the model from ``MODEL_DIR`` and run an interactive prompt loop.

    Each non-empty prompt is tokenized, passed to ``model.generate`` with the
    module-level sampling settings, and the returned sequence is printed with
    BOS/EOS/PAD tokens removed. Typing ``quit`` or ``exit``, or sending
    Ctrl-C / EOF at the prompt, ends the session. If the tokenizer or model
    fails to load, the error is printed and the function returns.
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"Running inference on: {device.upper()}")

    try:
        print(f"Loading model from {MODEL_DIR}...")
        tokenizer = PreTrainedTokenizerFast.from_pretrained(MODEL_DIR)
        model = LlamaForCausalLM.from_pretrained(MODEL_DIR)
        model.to(device)
        model.eval()
        print("Model loaded successfully!\n")
    except Exception as e:
        print(f"Failed to load. Error: {e}")
        return

    print("=" * 60)
    print("INTERACTIVE MODE: Ready! (Type 'quit' or 'exit' to stop)")
    print("=" * 60)

    # Get special tokens to ignore during decoding. None is included so that
    # unset special tokens on the tokenizer never match a real token.
    special_tokens = {tokenizer.bos_token, tokenizer.eos_token, tokenizer.pad_token, None}

    # Loop-invariant: fall back to the EOS id when no pad token is defined.
    pad_token_id = tokenizer.pad_token_id
    if pad_token_id is None:
        pad_token_id = tokenizer.eos_token_id

    while True:
        try:
            prompt = input(f"\n>>> Enter prompt (temp: {TEMPERATURE}): ")
        except (KeyboardInterrupt, EOFError):
            print("\nExiting...")
            break

        stripped = prompt.strip()
        if stripped.lower() in ["quit", "exit"]:
            print("Goodbye!")
            break
        if not stripped:
            continue

        inputs = tokenizer(prompt, return_tensors="pt").to(device)
        # Drop token_type_ids if the tokenizer produced them so they are not
        # forwarded to generate().
        inputs.pop("token_type_ids", None)

        with torch.no_grad():
            outputs = model.generate(
                **inputs,
                max_new_tokens=MAX_NEW_TOKENS,
                temperature=TEMPERATURE,
                top_p=TOP_P,
                do_sample=True,
                pad_token_id=pad_token_id,
            )

        token_ids = outputs[0].tolist()
        raw_tokens = tokenizer.convert_ids_to_tokens(token_ids)
        clean_tokens = [tok for tok in raw_tokens if tok not in special_tokens]
        # Full byte-level decode: handles spaces and newlines as before, and
        # also tabs and multi-byte UTF-8 characters.
        generated_text = _decode_byte_level(clean_tokens)

        print("-" * 60)
        print(generated_text.strip())
        print("-" * 60)
# Start the interactive session only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()