# Niche-train / app.py
import os
import json
import torch
import streamlit as st
from datasets import Dataset
from transformers import AutoTokenizer, AutoModelForCausalLM, Trainer, TrainingArguments
from peft import LoraConfig, get_peft_model
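# The imports above assume these packages are installed in the Space:
# streamlit, torch, transformers, datasets, peft, huggingface_hub.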
# -------- CONFIG ----------
MODEL_ID = "Neon-AI/Niche"
CHECKPOINT_DIR = "./checkpoints"
HF_TOKEN = st.secrets["HF_TOKEN"] # Put your HF token in Streamlit secrets
st.title("🧠 Niche Trainer with Push to HF")
# ---------- Load model (cached, so it loads only on first use) ----------
@st.cache_resource
def load_model():
    tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_ID,
        torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32
    )
    if tokenizer.pad_token is None:
        # Reuse EOS for padding; no new token is added, so the resize is a no-op safeguard.
        tokenizer.pad_token = tokenizer.eos_token
        model.resize_token_embeddings(len(tokenizer))
    return tokenizer, model

tokenizer, model = load_model()
# ---------- LoRA / Full model selection ----------
finetune_type = st.radio("Select fine-tune type:", ["Full model", "LoRA"])
# ---------- JSON input ----------
st.subheader("Paste your JSON training examples")
json_input = st.text_area(
    'JSON format: [{"prompt": "Hello", "response": "Hi there!"}, ...]',
    height=300,
    placeholder='[{"prompt": "...", "response": "..."}]'
)
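# Illustrative payload matching the expected schema:
# [
#   {"prompt": "Hello", "response": "Hi there!"},
#   {"prompt": "How are you?", "response": "Doing well, thanks!"}
# ]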
# ---------- Max token length ----------
max_len = st.slider("Max token length", min_value=64, max_value=512, value=256)
# ---------- Train ----------
# Streamlit reruns this script on every widget interaction, so a plain
# module-level flag would reset between button clicks; track training
# state in st.session_state instead.
if "train_started" not in st.session_state:
    st.session_state.train_started = False

if st.button("Train"):
    try:
        examples = json.loads(json_input)
        if not examples:
            st.warning("No examples provided!")
        else:
            texts = [
                f"### User:\n{e['prompt']}\n\n### Assistant:\n{e['response']}"
                for e in examples
            ]
            ds = Dataset.from_dict({"text": texts})

            def tokenize(batch):
                out = tokenizer(batch["text"], truncation=True, padding="max_length", max_length=max_len)
                out["labels"] = out["input_ids"].copy()
                return out

            ds = ds.map(tokenize, batched=True)
            ds.set_format("torch")
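            # Note: labels deliberately include the padded positions. The usual
            # trick of masking pads with -100 would also mask EOS here, since
            # pad_token was aliased to eos_token at load time.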
            # ---------- Apply LoRA if selected ----------
            if finetune_type == "LoRA":
                peft_config = LoraConfig(
                    task_type="CAUSAL_LM",
                    r=16,
                    lora_alpha=32,
                    lora_dropout=0.1,
                    # "c_attn" matches GPT-2-style attention blocks; other
                    # architectures name their projections differently
                    # (e.g. "q_proj"/"v_proj" in Llama-style models).
                    target_modules=["c_attn"]
                )
                train_model = get_peft_model(model, peft_config)
            else:
                train_model = model
            args = TrainingArguments(
                output_dir=CHECKPOINT_DIR,
                per_device_train_batch_size=1,
                gradient_accumulation_steps=2,
                num_train_epochs=1,
                learning_rate=2e-5,
                logging_steps=1,
                save_strategy="no",
                report_to="none",
            )
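            # save_strategy="no" skips intermediate checkpoints; the final
            # weights are written out manually in the push section below.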
            trainer = Trainer(
                model=train_model,
                args=args,
                train_dataset=ds
            )
            st.info("Training started...")
            trainer.train()
            st.success("✅ Training done!")
            # Persist the trained model across Streamlit reruns so the
            # push and chat sections below can reach it.
            st.session_state.train_started = True
            st.session_state.trained_model = train_model
    except Exception as e:
        st.error(f"Error during training: {e}")
# ---------- Push to HF ----------
if st.session_state.train_started and st.button("Push to Hugging Face"):
    try:
        trained = st.session_state.trained_model
        # Save trained model + tokenizer locally as a checkpoint
        trained.save_pretrained(CHECKPOINT_DIR)
        tokenizer.save_pretrained(CHECKPOINT_DIR)
        # Repository is deprecated in recent huggingface_hub releases,
        # so push directly with push_to_hub instead.
        trained.push_to_hub(MODEL_ID, token=HF_TOKEN,
                            commit_message="Update Niche model with new training")
        tokenizer.push_to_hub(MODEL_ID, token=HF_TOKEN)
        st.success("✅ Model pushed to HF successfully!")
    except Exception as e:
        st.error(f"Push failed: {e}")
# ---------- Chat ----------
st.subheader("Test the model")
user_prompt = st.text_input("You:", "")
if st.button("Send"):
    if user_prompt.strip():
        # Prefer the model trained this session; fall back to the base model.
        chat_model = st.session_state.get("trained_model", model)
        # Wrap the input in the same template used during training.
        prompt = f"### User:\n{user_prompt}\n\n### Assistant:\n"
        inputs = tokenizer(prompt, return_tensors="pt").to(chat_model.device)
        outputs = chat_model.generate(**inputs, max_new_tokens=100, do_sample=True, temperature=0.7)
        response = tokenizer.decode(outputs[0], skip_special_tokens=True)
        st.text_area("Niche:", value=response, height=200)