# Provenance (Hugging Face Space): uploaded by cazyundee, commit "Update app.py" (1ef0f11, verified)
import streamlit as st
import torch
import os
from model import LanguageModel, encode, decode
# Configure the browser tab title and icon for this Streamlit app.
st.set_page_config(page_title="Nano-Llama Shakespeare", page_icon="🎭")
@st.cache_resource
def load_llama_model():
    """Load the Shakespeare language model from ``model.pt`` (cached across reruns).

    Returns:
        The ``LanguageModel`` on CPU in eval mode, or ``None`` if the weights
        file is missing (a Streamlit error banner is shown in that case).
    """
    device = 'cpu'
    model = LanguageModel().to(device)
    # Weights file lives in the app root alongside this script.
    model_path = 'model.pt'
    if not os.path.exists(model_path):
        st.error(f"Could not find model weights at {model_path}.")
        return None
    # weights_only=True restricts unpickling to tensors/primitives, so a
    # tampered checkpoint cannot execute arbitrary code on load.
    checkpoint = torch.load(model_path, map_location=device, weights_only=True)
    model.load_state_dict(checkpoint)
    model.eval()
    return model
st.title("🎭 Nano-Llama Shakespeare")

model = load_llama_model()
if model:
    prompt = st.text_input("Enter a prompt:", "ROMEO: ")
    if st.button("Generate"):
        # Encode the prompt into a (1, seq_len) batch of token ids.
        context = torch.tensor([encode(prompt)], dtype=torch.long)
        # Inference only: model.eval() does not disable autograd, so wrap
        # generation in no_grad() to avoid building an unused graph.
        with torch.no_grad():
            generated_ids = model.generate(context, max_new_tokens=300)
        st.text_area("Result", value=decode(generated_ids[0].tolist()), height=400)