Spaces:
Paused
Paused
| import os | |
| import json | |
| import torch | |
| import streamlit as st | |
| from datasets import Dataset | |
| from transformers import AutoTokenizer, AutoModelForCausalLM, Trainer, TrainingArguments | |
| from peft import LoraConfig, get_peft_model | |
| from huggingface_hub import Repository | |
# -------- CONFIG ----------
MODEL_ID = "Neon-AI/Niche"        # base model repo id on the Hugging Face Hub
CHECKPOINT_DIR = "./checkpoints"  # local dir for trainer output / hub repo clone
HF_TOKEN = st.secrets["HF_TOKEN"]  # Put your HF token in Streamlit secrets

st.title("🧠 Niche Trainer with Push to HF")
# ---------- Load model once ----------
@st.cache_resource(show_spinner="Loading base model...")
def load_model():
    """Load the tokenizer and base causal LM exactly once.

    Streamlit re-executes this script on every widget interaction;
    ``st.cache_resource`` makes reruns reuse the already-loaded objects
    instead of re-instantiating (and potentially re-downloading) them.

    Returns:
        tuple: ``(tokenizer, model)``.
    """
    tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_ID,
        # fp16 only when a GPU is available; fp32 keeps CPU numerics stable.
        torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
    )
    # Causal LMs often ship without a pad token; reuse EOS so that
    # padding="max_length" tokenization below works.
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    # Keep the embedding matrix in sync with the tokenizer vocab.
    model.resize_token_embeddings(len(tokenizer))
    return tokenizer, model

tokenizer, model = load_model()
# ---------- LoRA / Full model selection ----------
finetune_type = st.radio("Select fine-tune type:", ["Full model", "LoRA"])

# ---------- JSON input ----------
st.subheader("Paste your JSON training examples")
json_input = st.text_area(
    "JSON format: [{'prompt': 'Hello', 'response': 'Hi there!'}, ...]",
    height=300,
    placeholder='[{"prompt": "...", "response": "..."}]'
)

# ---------- Max token length ----------
max_len = st.slider("Max token length", min_value=64, max_value=512, value=256)

# ---------- Train ----------
# Persist the "training finished" flag across Streamlit reruns: a plain
# local would reset to False on the rerun triggered by the next button
# click, making any button gated on it unclickable.
if "train_started" not in st.session_state:
    st.session_state.train_started = False
train_started = st.session_state.train_started
if st.button("Train"):
    try:
        examples = json.loads(json_input)
        if not examples:
            st.warning("No examples provided!")
        else:
            # Render each example into the prompt/response chat template.
            texts = [
                f"### User:\n{e['prompt']}\n\n### Assistant:\n{e['response']}"
                for e in examples
            ]
            ds = Dataset.from_dict({"text": texts})

            def tokenize(batch):
                # Causal-LM objective: labels are the input ids themselves
                # (Trainer shifts them internally).
                out = tokenizer(batch["text"], truncation=True, padding="max_length", max_length=max_len)
                out["labels"] = out["input_ids"].copy()
                return out

            ds = ds.map(tokenize, batched=True)
            ds.set_format("torch")

            # ---------- Apply LoRA if selected ----------
            if finetune_type == "LoRA":
                peft_config = LoraConfig(
                    task_type="CAUSAL_LM",
                    r=16,
                    lora_alpha=32,
                    lora_dropout=0.1,
                    # NOTE(review): "c_attn" is the GPT-2-style projection name;
                    # confirm it matches the attention modules of MODEL_ID.
                    target_modules=["c_attn"]
                )
                train_model = get_peft_model(model, peft_config)
            else:
                train_model = model

            args = TrainingArguments(
                output_dir=CHECKPOINT_DIR,
                per_device_train_batch_size=1,
                gradient_accumulation_steps=2,  # effective batch size of 2
                num_train_epochs=1,
                learning_rate=2e-5,
                logging_steps=1,
                save_strategy="no",   # we save explicitly before pushing
                report_to="none",
            )
            trainer = Trainer(
                model=train_model,
                args=args,
                train_dataset=ds
            )
            st.info("Training started...")
            trainer.train()
            st.success("✅ Training done!")
            train_started = True
            # Mirror results into session state so the Push button and the
            # chat section still see them after the next rerun.
            st.session_state.train_started = True
            st.session_state.trained_model = train_model
            # Use trained model for chat in this run as well.
            model = train_model
    except Exception as e:
        # UI boundary: surface any failure to the user instead of crashing.
        st.error(f"Error during training: {e}")
# ---------- Push to HF ----------
# Gate on the session-state flag (survives reruns) but fall back to the
# local flag so pushing also works within the same run that trained.
if (st.session_state.get("train_started", False) or train_started) and st.button("Push to Hugging Face"):
    try:
        # Reuse an existing local clone; otherwise clone the model repo.
        # NOTE(review): Repository is deprecated in recent huggingface_hub
        # releases — consider model.push_to_hub(...) or HfApi.upload_folder.
        if os.path.exists(CHECKPOINT_DIR):
            repo = Repository(local_dir=CHECKPOINT_DIR, use_auth_token=HF_TOKEN)
        else:
            repo = Repository(local_dir=CHECKPOINT_DIR, clone_from=MODEL_ID, use_auth_token=HF_TOKEN)
        # Save the trained weights + tokenizer into the repo working tree.
        # Prefer the model kept in session state; fall back to the global.
        push_model = st.session_state.get("trained_model", model)
        push_model.save_pretrained(CHECKPOINT_DIR)
        tokenizer.save_pretrained(CHECKPOINT_DIR)
        # Commit and push the working tree to the Hub.
        repo.push_to_hub(commit_message="Update Niche model with new training")
        st.success("✅ Model pushed to HF successfully!")
    except Exception as e:
        st.error(f"Push failed: {e}")
# ---------- Chat ----------
st.subheader("Test the model")
user_prompt = st.text_input("You:", "")
if st.button("Send"):
    if user_prompt.strip():
        # Prefer a freshly trained model surviving in session state; the
        # module-level `model` is the base model again after a rerun.
        chat_model = st.session_state.get("trained_model", model)
        inputs = tokenizer(user_prompt, return_tensors="pt").to(chat_model.device)
        # Inference only — disable autograd to avoid grad buffers.
        with torch.no_grad():
            outputs = chat_model.generate(**inputs, max_new_tokens=100, do_sample=True, temperature=0.7)
        response = tokenizer.decode(outputs[0], skip_special_tokens=True)
        st.text_area("Niche:", value=response, height=200)