Spaces:
Runtime error
| import streamlit as st | |
| import torch | |
| import os | |
| from model import LanguageModel, encode, decode | |
# Configure the browser tab (title + icon). Kept at the top of the script so it
# runs before any other Streamlit element is rendered.
st.set_page_config(
    page_title="Nano-Llama Shakespeare",
    page_icon="🎭",
)
@st.cache_resource
def _load_model_from_disk(model_path, device):
    """Build a LanguageModel on *device* and load the state dict at *model_path*.

    Decorated with ``st.cache_resource`` so the checkpoint is read once per
    server process (keyed on ``model_path`` and ``device``) instead of on every
    Streamlit rerun. The returned model object is shared across sessions; the
    app only uses it for inference.
    """
    model = LanguageModel().to(device)
    # weights_only=True limits unpickling to tensors and plain containers, so a
    # tampered checkpoint cannot run arbitrary code on load.
    checkpoint = torch.load(model_path, map_location=device, weights_only=True)
    # The loaded object is passed straight to load_state_dict, so model.pt must
    # hold a bare state dict matching LanguageModel's parameters.
    model.load_state_dict(checkpoint)
    model.eval()
    return model


def load_llama_model():
    """Return the Shakespeare LanguageModel in eval mode on CPU.

    Returns:
        The loaded model, or ``None`` (after showing an ``st.error``) when the
        weights file ``model.pt`` is not present.
    """
    device = 'cpu'
    # Relative path: resolved against the process's working directory.
    model_path = 'model.pt'
    # Check before touching the cache so a missing file is reported on every
    # rerun rather than a cached ``None`` being served forever. isfile (not
    # exists) so a directory named model.pt is also reported, not crashed on.
    if not os.path.isfile(model_path):
        st.error(f"Could not find model weights at {model_path}.")
        return None
    return _load_model_from_disk(model_path, device)
st.title("🎭 Nano-Llama Shakespeare")

# None means the weights file was missing; load_llama_model already showed the error.
model = load_llama_model()

# Explicit None check: don't rely on nn.Module truthiness.
if model is not None:
    prompt = st.text_input("Enter a prompt:", "ROMEO: ")
    if st.button("Generate"):
        if not prompt:
            # encode("") would yield a (1, 0) context tensor, leaving generation
            # nothing to condition on.
            st.warning("Please enter a prompt before generating.")
        else:
            # Batch of one sequence: shape (1, len(encode(prompt))).
            context = torch.tensor([encode(prompt)], dtype=torch.long)
            # Pure inference: disable autograd so no graph/activations are kept.
            with st.spinner("Generating..."), torch.no_grad():
                generated_ids = model.generate(context, max_new_tokens=300)
            st.text_area("Result", value=decode(generated_ids[0].tolist()), height=400)