"""Streamlit front end that samples text from the Nano-Llama Shakespeare model."""

import os

import streamlit as st
import torch

from model import LanguageModel, encode, decode

# Device used for both the model weights and the prompt tensor.
DEVICE = 'cpu'
# Relative path: the weights file is expected in the project root.
MODEL_PATH = 'model.pt'
# Number of tokens sampled after the prompt on each "Generate" click.
MAX_NEW_TOKENS = 300

st.set_page_config(page_title="Nano-Llama Shakespeare", page_icon="🎭")


@st.cache_resource
def _load_model_from_disk(model_path, device):
    """Build a ``LanguageModel``, load its weights, and cache the instance.

    This helper is only called once ``model_path`` is known to exist, so a
    "file missing" outcome is never stored in Streamlit's resource cache.

    Args:
        model_path: Path of the saved ``state_dict`` file.
        device: Torch device string the model and weights are mapped to.

    Returns:
        The model with weights loaded, switched to evaluation mode.
    """
    model = LanguageModel().to(device)
    # weights_only=True keeps torch.load from unpickling arbitrary objects.
    checkpoint = torch.load(model_path, map_location=device, weights_only=True)
    model.load_state_dict(checkpoint)
    model.eval()
    return model


def load_llama_model():
    """Return the cached model, or ``None`` if the weights file is missing.

    The existence check deliberately lives outside the cached helper: caching
    a ``None`` result would keep reporting the error even after the weights
    file has been put in place.

    Returns:
        The loaded model, or ``None`` (after showing an error in the UI) when
        ``MODEL_PATH`` does not exist.
    """
    if not os.path.exists(MODEL_PATH):
        st.error(f"Could not find model weights at {MODEL_PATH}.")
        return None
    return _load_model_from_disk(MODEL_PATH, DEVICE)


st.title("🎭 Nano-Llama Shakespeare")

model = load_llama_model()

# Compare against None explicitly: truthiness of a model object depends on
# whether its class defines __len__/__bool__, which is not what we mean here.
if model is not None:
    prompt = st.text_input("Enter a prompt:", "ROMEO: ")
    if st.button("Generate"):
        if not prompt:
            # An empty prompt would encode to an empty context; presumably
            # generate() cannot sample from that, so ask for input instead.
            st.warning("Please enter a non-empty prompt.")
        else:
            # Leading list adds the batch dimension expected by generate().
            context = torch.tensor(
                [encode(prompt)], dtype=torch.long, device=DEVICE
            )
            # Pure inference: skip autograd bookkeeping while sampling.
            with torch.no_grad():
                generated_ids = model.generate(
                    context, max_new_tokens=MAX_NEW_TOKENS
                )
            st.text_area(
                "Result",
                value=decode(generated_ids[0].tolist()),
                height=400,
            )