# Spaces: Sleeping
import os
import pickle
import warnings

import gradio as gr
import torch

from nanoGPT.model import GPT, GPTConfig
# Load model config and weights.
# NOTE: pickle.load and torch.load(weights_only=False) can execute arbitrary
# code stored in the file; only point these paths at artifacts you created.
with open("out_ai_gf/config.pkl", "rb") as f:
    model_args = pickle.load(f)
model = GPT(GPTConfig(**model_args))


def _remap_checkpoint_key(key):
    """Translate a checkpoint parameter name into the name this GPT expects.

    Drops the ``_orig_mod.`` prefix that parameters saved from a
    ``torch.compile``-wrapped model carry, then renames ``ln_1``/``ln_2`` to
    ``ln1``/``ln2`` to match this model's layer-norm attribute names.
    """
    prefix = "_orig_mod."
    if key.startswith(prefix):
        key = key[len(prefix):]
    return key.replace("ln_1", "ln1").replace("ln_2", "ln2")


checkpoint = torch.load("out_ai_gf/ckpt.pt", map_location="cpu", weights_only=False)
fixed_state_dict = {_remap_checkpoint_key(k): v for k, v in checkpoint["model"].items()}

# strict=False tolerates naming drift between the training code and this model.
# On its own, though, it silently leaves any unmatched parameter at whatever
# value GPT(...) initialised it with. Report the mismatch so a partial load is
# visible instead of showing up only as nonsense output.
load_result = model.load_state_dict(fixed_state_dict, strict=False)
if load_result.missing_keys or load_result.unexpected_keys:
    warnings.warn(
        "Checkpoint does not match the model exactly: "
        f"missing={load_result.missing_keys}, "
        f"unexpected={load_result.unexpected_keys}"
    )
model.eval()
# Tokenizer lookup tables stored next to the checkpoint:
#   itos -- token id -> character (used by decode)
#   stoi -- character -> token id (used by encode)
# NOTE: pickle.load executes code from the file; only load trusted artifacts.
with open("out_ai_gf/meta.pkl", "rb") as meta_file:
    meta = pickle.load(meta_file)
itos, stoi = meta["itos"], meta["stoi"]
def encode(text, vocab=None):
    """Map each character of *text* to its token id.

    Args:
        text: String to tokenize; every character becomes one token.
        vocab: char -> id mapping. Defaults to the module-level ``stoi``.

    Returns:
        List of int token ids, one per character of *text*. Characters that
        are not in *vocab* map to id 0.
        NOTE(review): id 0 appears to be an ordinary vocabulary entry rather
        than a dedicated unknown token; confirm this fallback is intended.
    """
    if vocab is None:
        vocab = stoi
    return [vocab.get(ch, 0) for ch in text]
def decode(tokens, vocab=None):
    """Convert token ids back to display text.

    Args:
        tokens: Iterable of int token ids.
        vocab: id -> char mapping. Defaults to the module-level ``itos``.

    Returns:
        The decoded string. Ids missing from *vocab* are dropped. Characters
        outside printable ASCII (other than newline and tab) are replaced with
        U+FFFD so stray control characters don't reach the UI.
    """
    if vocab is None:
        vocab = itos
    text = "".join(vocab.get(i, "") for i in tokens)
    # The original mask only kept ord 32..126, which also replaced every line
    # break and tab in the model's output with U+FFFD. Keep those whitespace
    # characters and mask only the rest.
    return "".join(
        c if 32 <= ord(c) <= 126 or c in "\n\t" else "\ufffd" for c in text
    )
# Chat function
def chat_fn(user_input, max_new_tokens=100):
    """Generate the model's continuation of *user_input*.

    Args:
        user_input: Text typed by the user into the Gradio textbox.
        max_new_tokens: Number of tokens to sample after the prompt.

    Returns:
        Decoded text of the newly generated tokens only (the prompt is
        excluded), or "" when *user_input* is empty or None.
    """
    # An empty prompt would give generate() a zero-length context, leaving no
    # final position to sample the next token from (nanoGPT's reference
    # generate indexes logits[:, -1, :]).
    # NOTE(review): confirm for this project's GPT.generate.
    if not user_input:
        return ""
    x = torch.tensor([encode(user_input)], dtype=torch.long)
    with torch.no_grad():
        y = model.generate(x, max_new_tokens=max_new_tokens)[0].tolist()
    # The slice assumes generate() returns prompt + continuation; drop the
    # prompt so only the reply is shown.
    return decode(y[x.size(1):])
# Gradio interface.
# Bound to `demo` so the app can be imported without starting a server
# (Gradio's `gradio app.py` reload mode also looks for a variable named
# `demo`). The server is launched only when this file runs as a script.
demo = gr.Interface(
    fn=chat_fn,
    inputs=gr.Textbox(label="You"),
    outputs=gr.Textbox(label="Lin Yao 💬"),
    title="Lin Yao — AI Girlfriend",
    description="Chat live with your custom trained AI girlfriend!",
    theme="soft",
)

if __name__ == "__main__":
    demo.launch()