import os
import json
import torch
import streamlit as st
from datasets import Dataset
from transformers import AutoTokenizer, AutoModelForCausalLM, Trainer, TrainingArguments
from peft import LoraConfig, get_peft_model
from huggingface_hub import Repository
# -------- CONFIG ----------
# Hugging Face Hub repository id of the model loaded by load_model().
MODEL_ID = "Neon-AI/Niche"
# Local directory passed to TrainingArguments as output_dir.
CHECKPOINT_DIR = "./checkpoints"
# Hub access token, read from Streamlit secrets under the key "HF_TOKEN".
# Indexing st.secrets raises at startup if the secret is not configured.
HF_TOKEN = st.secrets["HF_TOKEN"] # Put your HF token in Streamlit secrets
# Page heading shown at the top of the app.
st.title("🧠 Niche Trainer with Push to HF")
# ---------- Load model once ----------
# The tokenizer and model are loaded right below, at script start.
# load_model() is wrapped in st.cache_resource so that Streamlit reruns
# reuse the already-loaded objects instead of loading them again.
@st.cache_resource
def load_model():
    """Load the tokenizer and causal language model for ``MODEL_ID``.

    The function is wrapped in ``st.cache_resource`` so repeated calls
    return the cached objects instead of loading them again.

    Returns:
        tuple: ``(tokenizer, model)``. The tokenizer is guaranteed to have
        a pad token (EOS is reused when none is defined).
    """
    tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)

    # Half precision only when a CUDA device is available; float32 otherwise.
    dtype = torch.float16 if torch.cuda.is_available() else torch.float32
    model = AutoModelForCausalLM.from_pretrained(MODEL_ID, torch_dtype=dtype)

    if tokenizer.pad_token is None:
        # Reusing EOS as the pad token adds no new entry to the vocabulary.
        tokenizer.pad_token = tokenizer.eos_token
        # Only grow the embedding matrix when the tokenizer really has more
        # entries than the model has embedding rows. Resizing
        # unconditionally to len(tokenizer) would shrink a matrix that the
        # checkpoint stores with more rows than the tokenizer has entries.
        embedding_rows = model.get_input_embeddings().weight.shape[0]
        if len(tokenizer) > embedding_rows:
            model.resize_token_embeddings(len(tokenizer))

    return tokenizer, model
# Fetch the tokenizer/model pair used by the rest of the script.
tokenizer, model = load_model()

# ---------- LoRA / Full model selection ----------
# Radio button that decides whether training wraps the model with LoRA.
finetune_type = st.radio(
    label="Select fine-tune type:",
    options=["Full model", "LoRA"],
)

# ---------- JSON input ----------
# Free-form text area holding the training examples as a JSON list.
st.subheader("Paste your JSON training examples")
json_input = st.text_area(
    label="JSON format: [{'prompt': 'Hello', 'response': 'Hi there!'}, ...]",
    placeholder='[{"prompt": "...", "response": "..."}]',
    height=300,
)

# ---------- Max token length ----------
# Sequence length every example is truncated/padded to during tokenization.
max_len = st.slider(
    label="Max token length",
    min_value=64,
    max_value=512,
    value=256,
)
# ---------- Train ----------
# Streamlit re-executes this script from the top on every widget interaction.
# State that must outlive a single run (the "training finished" flag and the
# trained model) is therefore kept in st.session_state; a plain variable would
# be reset by the rerun that the "Push to Hugging Face" click triggers, and
# the push branch could never execute.
train_started = st.session_state.get("train_started", False)
trained_model = st.session_state.get("trained_model")

if st.button("Train"):
    try:
        examples = json.loads(json_input)
        if not examples:
            st.warning("No examples provided!")
        else:
            # Render every example with the prompt template used for training.
            texts = [
                f"### User:\n{e['prompt']}\n\n### Assistant:\n{e['response']}"
                for e in examples
            ]
            ds = Dataset.from_dict({"text": texts})

            def tokenize(batch):
                """Tokenize a batch of texts and attach causal-LM labels.

                Labels copy the input ids, except that padded positions
                (attention mask 0) are set to -100 so that padding is
                ignored by the loss instead of being trained on.
                """
                out = tokenizer(
                    batch["text"],
                    truncation=True,
                    padding="max_length",
                    max_length=max_len,
                    return_attention_mask=True,
                )
                out["labels"] = [
                    [tok if keep else -100 for tok, keep in zip(ids, mask)]
                    for ids, mask in zip(out["input_ids"], out["attention_mask"])
                ]
                return out

            ds = ds.map(tokenize, batched=True)
            ds.set_format("torch")

            # ---------- Apply LoRA if selected ----------
            # NOTE(review): the cached base model is passed to get_peft_model
            # on every LoRA run; confirm that repeated LoRA runs in the same
            # process behave as intended.
            if finetune_type == "LoRA":
                peft_config = LoraConfig(
                    task_type="CAUSAL_LM",
                    r=16,
                    lora_alpha=32,
                    lora_dropout=0.1,
                    target_modules=["c_attn"],
                )
                train_model = get_peft_model(model, peft_config)
            else:
                train_model = model

            args = TrainingArguments(
                output_dir=CHECKPOINT_DIR,
                per_device_train_batch_size=1,
                gradient_accumulation_steps=2,
                num_train_epochs=1,
                learning_rate=2e-5,
                logging_steps=1,
                save_strategy="no",
                report_to="none",
            )
            trainer = Trainer(
                model=train_model,
                args=args,
                train_dataset=ds,
            )

            st.info("Training started...")
            trainer.train()
            st.success("✅ Training done!")

            # Persist the outcome so the next rerun (e.g. the push click)
            # still sees that training happened and which model to upload.
            train_started = True
            trained_model = train_model
            st.session_state["train_started"] = True
            st.session_state["trained_model"] = train_model

            # Use trained model for chat
            model = train_model
    except Exception as e:
        st.error(f"Error during training: {e}")

# ---------- Push to HF ----------
if train_started and trained_model is not None and st.button("Push to Hugging Face"):
    try:
        # Upload directly from the in-memory objects. The previous approach
        # opened CHECKPOINT_DIR as a git working copy, which fails when the
        # directory exists but is not a clone of the Hub repository.
        commit_message = "Update Niche model with new training"
        trained_model.push_to_hub(
            MODEL_ID,
            token=HF_TOKEN,
            commit_message=commit_message,
        )
        tokenizer.push_to_hub(
            MODEL_ID,
            token=HF_TOKEN,
            commit_message=commit_message,
        )
        st.success("✅ Model pushed to HF successfully!")
    except Exception as e:
        st.error(f"Push failed: {e}")
# ---------- Chat ----------
# Simple single-turn playground: send one prompt, show the sampled completion.
st.subheader("Test the model")
user_prompt = st.text_input("You:", "")
if st.button("Send") and user_prompt.strip():
    # Make sure the model is in inference mode before sampling.
    # NOTE(review): presumably the model is left in training mode after
    # trainer.train(); eval() keeps dropout from affecting generation.
    model.eval()
    inputs = tokenizer(user_prompt, return_tensors="pt").to(model.device)
    outputs = model.generate(
        **inputs,
        max_new_tokens=100,
        do_sample=True,
        temperature=0.7,
        # load_model() guarantees the tokenizer has a pad token.
        pad_token_id=tokenizer.pad_token_id,
    )
    # outputs[0] is decoded in full, so the text shown includes the prompt.
    response = tokenizer.decode(outputs[0], skip_special_tokens=True)
    st.text_area("Niche:", value=response, height=200)