"""Streamlit app: fine-tune the Niche model (full or LoRA) on pasted JSON
examples, optionally push the result to the Hugging Face Hub, and chat with it.

Run with: streamlit run <this file>
"""

import os
import json

import torch
import streamlit as st
from datasets import Dataset
from transformers import AutoTokenizer, AutoModelForCausalLM, Trainer, TrainingArguments
from peft import LoraConfig, get_peft_model
from huggingface_hub import Repository

# -------- CONFIG ----------
MODEL_ID = "Neon-AI/Niche"
CHECKPOINT_DIR = "./checkpoints"
HF_TOKEN = st.secrets["HF_TOKEN"]  # Put your HF token in Streamlit secrets

st.title("🧠 Niche Trainer with Push to HF")


# ---------- Load model once ----------
@st.cache_resource
def load_model():
    """Load tokenizer + base model once per server process (cached).

    Returns:
        (tokenizer, model) tuple. Uses fp16 on GPU, fp32 on CPU. Ensures a
        pad token exists (falls back to EOS) so `padding="max_length"` works.
    """
    tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_ID,
        torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
    )
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
        # pad=eos adds no new token, so this resize is a no-op safeguard.
        model.resize_token_embeddings(len(tokenizer))
    return tokenizer, model


tokenizer, model = load_model()

# BUGFIX: every button click reruns the whole script, so a plain local flag
# (`train_started = False`) resets each time and the "Push to Hugging Face"
# button could never appear/work, and the trained model would be lost before
# push/chat could use it. Persist both in st.session_state instead.
st.session_state.setdefault("train_started", False)
st.session_state.setdefault("trained_model", None)

# ---------- LoRA / Full model selection ----------
finetune_type = st.radio("Select fine-tune type:", ["Full model", "LoRA"])

# ---------- JSON input ----------
st.subheader("Paste your JSON training examples")
json_input = st.text_area(
    "JSON format: [{'prompt': 'Hello', 'response': 'Hi there!'}, ...]",
    height=300,
    placeholder='[{"prompt": "...", "response": "..."}]',
)

# ---------- Max token length ----------
max_len = st.slider("Max token length", min_value=64, max_value=512, value=256)

# ---------- Train ----------
if st.button("Train"):
    try:
        examples = json.loads(json_input)
        if not examples:
            st.warning("No examples provided!")
        else:
            texts = [
                f"### User:\n{e['prompt']}\n\n### Assistant:\n{e['response']}"
                for e in examples
            ]
            ds = Dataset.from_dict({"text": texts})

            def tokenize(batch):
                # Labels are a copy of input_ids: standard causal-LM objective.
                out = tokenizer(
                    batch["text"],
                    truncation=True,
                    padding="max_length",
                    max_length=max_len,
                )
                out["labels"] = out["input_ids"].copy()
                return out

            ds = ds.map(tokenize, batched=True)
            ds.set_format("torch")

            # ---------- Apply LoRA if selected ----------
            if finetune_type == "LoRA":
                peft_config = LoraConfig(
                    task_type="CAUSAL_LM",
                    r=16,
                    lora_alpha=32,
                    lora_dropout=0.1,
                    # NOTE(review): "c_attn" matches GPT-2-style attention
                    # modules — confirm it exists in this model's architecture.
                    target_modules=["c_attn"],
                )
                train_model = get_peft_model(model, peft_config)
            else:
                train_model = model

            args = TrainingArguments(
                output_dir=CHECKPOINT_DIR,
                per_device_train_batch_size=1,
                gradient_accumulation_steps=2,
                num_train_epochs=1,
                learning_rate=2e-5,
                logging_steps=1,
                save_strategy="no",
                report_to="none",
            )
            trainer = Trainer(model=train_model, args=args, train_dataset=ds)

            st.info("Training started...")
            trainer.train()
            st.success("✅ Training done!")

            # BUGFIX: persist for the push/chat sections on later reruns;
            # plain locals would be lost when the next button is clicked.
            st.session_state["train_started"] = True
            st.session_state["trained_model"] = train_model
    except Exception as e:
        st.error(f"Error during training: {e}")

# Once training has run (this rerun or an earlier one), push/chat use the
# trained model instead of the cached base model.
if st.session_state["trained_model"] is not None:
    model = st.session_state["trained_model"]

# ---------- Push to HF ----------
if st.session_state["train_started"] and st.button("Push to Hugging Face"):
    try:
        # Prepare repo.  NOTE(review): huggingface_hub's Repository class is
        # deprecated upstream — consider upload_folder()/push_to_hub() later.
        if os.path.exists(CHECKPOINT_DIR):
            repo = Repository(local_dir=CHECKPOINT_DIR, use_auth_token=HF_TOKEN)
        else:
            repo = Repository(
                local_dir=CHECKPOINT_DIR,
                clone_from=MODEL_ID,
                use_auth_token=HF_TOKEN,
            )

        # Save trained model + tokenizer
        model.save_pretrained(CHECKPOINT_DIR)
        tokenizer.save_pretrained(CHECKPOINT_DIR)

        # Push
        repo.push_to_hub(commit_message="Update Niche model with new training")
        st.success("✅ Model pushed to HF successfully!")
    except Exception as e:
        st.error(f"Push failed: {e}")

# ---------- Chat ----------
st.subheader("Test the model")
user_prompt = st.text_input("You:", "")
if st.button("Send"):
    if user_prompt.strip():
        inputs = tokenizer(user_prompt, return_tensors="pt").to(model.device)
        # No gradients needed for generation — saves memory and time.
        with torch.inference_mode():
            outputs = model.generate(
                **inputs, max_new_tokens=100, do_sample=True, temperature=0.7
            )
        response = tokenizer.decode(outputs[0], skip_special_tokens=True)
        st.text_area("Niche:", value=response, height=200)