# ai-gf-chatbot / app.py
# Author: Lepish — commit f2cb5aa ("Update app.py", verified)
import torch
import pickle
import os
import gradio as gr
from nanoGPT.model import GPT, GPTConfig
# Build the model skeleton from the hyper-parameters pickled next to the
# checkpoint; the mapping is splatted straight into GPTConfig.
# NOTE(review): pickle.load executes arbitrary code embedded in the file —
# only ship trusted artifacts in out_ai_gf/.
with open("out_ai_gf/config.pkl", "rb") as config_file:
    model_args = pickle.load(config_file)

model = GPT(GPTConfig(**model_args))
# Load the trained weights into `model`.
# The checkpoint's keys use "ln_1"/"ln_2", while (per the original
# "ln_1 → ln1 fix") the local GPT implementation expects "ln1"/"ln2",
# so keys are renamed before loading.
# NOTE(review): weights_only=False unpickles arbitrary Python objects from the
# file — only ever point this at a checkpoint you produced / trust.
checkpoint = torch.load("out_ai_gf/ckpt.pt", map_location="cpu", weights_only=False)

# Checkpoints saved from a torch.compile()'d model prefix every key with
# "_orig_mod."; strip it so those weights are not silently discarded below.
_unwanted_prefix = "_orig_mod."
fixed_state_dict = {}
for k, v in checkpoint["model"].items():
    if k.startswith(_unwanted_prefix):
        k = k[len(_unwanted_prefix):]
    fixed_state_dict[k.replace("ln_1", "ln1").replace("ln_2", "ln2")] = v

# strict=False keeps the app starting even if some keys do not line up, but a
# silent mismatch would leave those parameters at their random initial values,
# so surface it in the logs.
incompatible = model.load_state_dict(fixed_state_dict, strict=False)
if incompatible.missing_keys or incompatible.unexpected_keys:
    print(
        "WARNING: checkpoint/model key mismatch -- "
        f"missing keys: {incompatible.missing_keys}, "
        f"unexpected keys: {incompatible.unexpected_keys}"
    )
model.eval()
# Character-level vocabulary saved with the checkpoint:
# `stoi` maps a character to its token id, `itos` maps an id back to a
# character (both are looked up with .get by the encode/decode helpers).
# NOTE(review): pickle.load executes arbitrary code embedded in the file.
with open("out_ai_gf/meta.pkl", "rb") as meta_file:
    meta = pickle.load(meta_file)
itos, stoi = meta["itos"], meta["stoi"]
def encode(text):
    """Return the list of token ids for *text*, one id per character.

    Characters absent from ``stoi`` are mapped to id 0.
    NOTE(review): whatever character id 0 stands for in this vocab is what
    unknown input turns into — confirm that is intended.
    """
    lookup = stoi.get
    return [lookup(ch, 0) for ch in text]
def decode(tokens):
    """Turn a sequence of token ids back into display text.

    Ids missing from ``itos`` are dropped. Printable ASCII (codes 32-126) is
    kept as-is, and so are newlines and tabs, so multi-line replies keep
    their layout. Every other character is replaced with '�' rather than
    removed.
    """
    text = "".join(itos.get(i, "") for i in tokens)
    return "".join(
        ch if ch in "\n\t" or 32 <= ord(ch) <= 126 else "�"
        for ch in text
    )
# Chat function
def chat_fn(user_input):
    """Generate the model's reply to *user_input*.

    The prompt is encoded character by character, and the model is asked for
    up to 100 new tokens. Only the newly generated tokens are decoded and
    returned. Empty (or ``None``) input returns ``""`` because there is no
    context to condition generation on.
    """
    if not user_input:
        # An empty prompt would yield a (1, 0) context tensor.
        # NOTE(review): nanoGPT-style generate() reads the last position of
        # the context, which fails on a zero-length sequence.
        return ""
    prompt_ids = encode(user_input)
    x = torch.tensor([prompt_ids], dtype=torch.long)
    with torch.no_grad():
        y = model.generate(x, max_new_tokens=100)[0].tolist()
    # The generated sequence begins with the prompt; drop that echo.
    return decode(y[len(prompt_ids):])
# Gradio UI: one text box in, one text box out, wired to chat_fn; launched
# immediately when this script runs.
user_box = gr.Textbox(label="You")
reply_box = gr.Textbox(label="Lin Yao 💬")
demo = gr.Interface(
    fn=chat_fn,
    inputs=user_box,
    outputs=reply_box,
    title="Lin Yao — AI Girlfriend",
    description="Chat live with your custom trained AI girlfriend!",
    theme="soft",
)
demo.launch()