File size: 1,018 Bytes
1ef0f11
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
import streamlit as st
import torch
import os
from model import LanguageModel, encode, decode

# Browser-tab title and icon; kept ahead of every other Streamlit call below.
st.set_page_config(
    page_icon="🎭",
    page_title="Nano-Llama Shakespeare",
)

@st.cache_resource(validate=lambda cached_model: cached_model is not None)
def load_llama_model(model_path='model.pt', device='cpu'):
    """Build the LanguageModel and load its trained weights for inference.

    The result is cached with ``st.cache_resource`` so the weights are read
    from disk once instead of on every script rerun.

    Args:
        model_path: Path to the saved weights passed to ``load_state_dict``.
            Defaults to ``model.pt`` in the working directory.
        device: Torch device the model and the loaded tensors are placed on.

    Returns:
        The model in evaluation mode, or ``None`` (after showing an error on
        the page) when ``model_path`` does not exist.
    """
    # Check first so a missing file doesn't pay for constructing the model.
    if not os.path.exists(model_path):
        st.error(f"Could not find model weights at {model_path}.")
        # The `validate` hook above rejects a cached None, so once the file
        # is present a later rerun retries the load instead of replaying
        # this failure forever.
        return None

    model = LanguageModel().to(device)
    # weights_only=True restricts unpickling to tensors and plain containers.
    state_dict = torch.load(model_path, map_location=device, weights_only=True)
    model.load_state_dict(state_dict)
    model.eval()  # evaluation mode (affects layers such as dropout, if any)
    return model

st.title("🎭 Nano-Llama Shakespeare")
model = load_llama_model()

# Compare against None explicitly: a torch module's truthiness is not a
# reliable "was it loaded" signal (modules that define __len__ can be falsy).
if model is not None:
    prompt = st.text_input("Enter a prompt:", "ROMEO: ")
    if st.button("Generate"):
        if not prompt:
            # NOTE(review): encode("") presumably yields no tokens, leaving
            # generate() an empty context it is unlikely to handle.
            st.warning("Please enter a prompt to generate from.")
        else:
            context = torch.tensor([encode(prompt)], dtype=torch.long)
            # Sampling needs no gradients; skip autograd bookkeeping.
            with torch.no_grad():
                generated_ids = model.generate(context, max_new_tokens=300)
            st.text_area("Result", value=decode(generated_ids[0].tolist()), height=400)