Spaces:
Runtime error
Runtime error
File size: 1,018 Bytes
1ef0f11 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 | import streamlit as st
import torch
import os
from model import LanguageModel, encode, decode
# Browser-tab metadata for the app: tab title plus an emoji favicon.
st.set_page_config(page_icon="🎭", page_title="Nano-Llama Shakespeare")
@st.cache_resource
def _load_cached_model(model_path, device='cpu'):
    """Build a LanguageModel, load its state dict from ``model_path``, and cache it.

    Streamlit keys the cache on the arguments, so each (path, device) pair is
    loaded once per server process. Only reached after the caller has confirmed
    the file exists, so a failed lookup is never cached.
    """
    model = LanguageModel().to(device)
    # weights_only=True limits unpickling to tensors and plain containers, so a
    # tampered checkpoint cannot run arbitrary code while loading.
    checkpoint = torch.load(model_path, map_location=device, weights_only=True)
    model.load_state_dict(checkpoint)
    # Put modules such as dropout (if the model has any) into inference mode.
    model.eval()
    return model


def load_llama_model(model_path='model.pt'):
    """Return the trained LanguageModel on CPU, ready for inference.

    The existence check runs on every call and is not cached. A missing weights
    file is therefore reported on each rerun, and the app recovers once the file
    is present. Previously a ``None`` result was cached for the life of the
    process.

    Args:
        model_path: Path to a state-dict checkpoint. Defaults to ``model.pt``,
            resolved relative to the current working directory.

    Returns:
        The model in eval mode. Returns ``None`` if ``model_path`` does not
        exist; an error message is rendered in the app in that case.
    """
    if not os.path.exists(model_path):
        st.error(f"Could not find model weights at {model_path}.")
        return None
    return _load_cached_model(model_path)
st.title("🎭 Nano-Llama Shakespeare")

model = load_llama_model()

# load_llama_model() returns None when the weights are missing (it has already
# shown an error). Compare against None explicitly, because nn.Module
# truthiness is not reliable: container-style modules define __len__.
if model is not None:
    prompt = st.text_input("Enter a prompt:", "ROMEO: ")
    if st.button("Generate"):
        if not prompt:
            # An empty prompt would produce a (1, 0) context tensor, which
            # autoregressive generation presumably cannot start from.
            st.warning("Please enter a non-empty prompt.")
        else:
            try:
                prompt_ids = encode(prompt)
            except KeyError as err:
                # NOTE(review): assumes encode() maps characters through a fixed
                # vocabulary dict and raises KeyError for unknown characters;
                # confirm against model.encode.
                st.error(f"The prompt contains a character the model does not know: {err}")
            else:
                context = torch.tensor([prompt_ids], dtype=torch.long)
                # Inference only: skip autograd bookkeeping during generation.
                with torch.no_grad():
                    generated_ids = model.generate(context, max_new_tokens=300)
                st.text_area("Result", value=decode(generated_ids[0].tolist()), height=400)