# toyllama-50m / test.py
# Last commit by sapbot: "Update test.py" (8fd8746, verified)
import sys

import torch
from transformers import AutoTokenizer, LlamaForCausalLM
# Identifier passed to `from_pretrained` for both the tokenizer and the model
# (a Hugging Face Hub repo id here; a local directory path also works).
MODEL_DIR: str = "sapbot/toyllama-50m"

# --- Generation Settings ---
# Sampling knobs forwarded to `model.generate(...)` for every prompt.
MAX_NEW_TOKENS: int = 150  # cap on newly generated tokens (prompt not counted)
TEMPERATURE: float = 0.7   # <1.0 sharpens the next-token distribution
TOP_P: float = 0.9         # nucleus sampling: keep the smallest set with cum. prob >= 0.9
def _load_model(device):
    """Load the tokenizer and model from ``MODEL_DIR`` and move the model to *device*.

    Returns:
        ``(tokenizer, model)`` on success, or ``None`` if loading failed
        (the failure has already been reported on stderr).
    """
    print(f"Loading model from {MODEL_DIR}...")
    try:
        tokenizer = AutoTokenizer.from_pretrained(MODEL_DIR)
        model = LlamaForCausalLM.from_pretrained(MODEL_DIR)
        model.to(device)
    except Exception as e:  # script boundary: any load failure (network, files, device) ends the run
        print(f"Failed to load. Error: {e}", file=sys.stderr)
        return None
    model.eval()
    print("Model loaded successfully!\n")
    return tokenizer, model


def _generate(tokenizer, model, prompt, device, pad_token_id):
    """Sample a continuation of *prompt* and return it decoded as text.

    The result is decoded from the full output sequence, so it contains the
    prompt followed by the newly sampled tokens (special tokens removed).
    """
    inputs = tokenizer(prompt, return_tensors="pt").to(device)
    # Some tokenizers emit token_type_ids, which the Llama model does not
    # accept as a generate() kwarg; drop it if present.
    inputs.pop("token_type_ids", None)
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=MAX_NEW_TOKENS,
            temperature=TEMPERATURE,
            top_p=TOP_P,
            do_sample=True,
            pad_token_id=pad_token_id,
        )
    # Single input sequence -> row 0 holds prompt + continuation.
    return tokenizer.decode(outputs[0], skip_special_tokens=True)


def main():
    """Run an interactive prompt -> sampled-completion loop.

    Loads the model once, then repeatedly reads a prompt from stdin and prints
    the decoded generation. The session ends on 'quit'/'exit' (any case), or
    on Ctrl-C / EOF at the prompt. Ctrl-C or a runtime error during generation
    cancels only that prompt. Returns early if the model cannot be loaded.
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"Running inference on: {device.upper()}")

    loaded = _load_model(device)
    if loaded is None:
        return
    tokenizer, model = loaded

    # Loop-invariant: fall back to EOS when the tokenizer defines no pad token,
    # which silences generate()'s missing-pad-token warning.
    pad_token_id = (
        tokenizer.pad_token_id
        if tokenizer.pad_token_id is not None
        else tokenizer.eos_token_id
    )

    print("=" * 60)
    print("INTERACTIVE MODE: Ready! (Type 'quit' or 'exit' to stop)")
    print("=" * 60)

    while True:
        try:
            prompt = input("\n>>> Enter prompt: ")
        except (KeyboardInterrupt, EOFError):
            print("\nExiting...")
            break

        stripped = prompt.strip()
        if stripped.lower() in ("quit", "exit"):
            print("Goodbye!")
            break
        if not stripped:
            continue

        try:
            generated_text = _generate(tokenizer, model, prompt, device, pad_token_id)
        except KeyboardInterrupt:
            # Ctrl-C during a long generation cancels it instead of crashing the session.
            print("\nGeneration interrupted.")
            continue
        except RuntimeError as e:
            # e.g. CUDA out-of-memory (a RuntimeError subclass); report and keep the REPL alive.
            print(f"Generation failed. Error: {e}", file=sys.stderr)
            continue

        print("-" * 60)
        print(generated_text.strip())
        print("-" * 60)
# Run the interactive session only when executed as a script, not on import.
if __name__ == "__main__":
    main()