File size: 4,762 Bytes
20b4e3c
cab4035
 
20b4e3c
cab4035
20b4e3c
cab4035
20b4e3c
8020b6e
cab4035
 
 
20b4e3c
8020b6e
cab4035
 
 
c2ec7cc
 
 
 
 
cab4035
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
20b4e3c
 
 
cab4035
 
 
 
 
 
 
 
 
 
 
 
 
 
 
20b4e3c
cab4035
 
 
 
 
 
 
 
 
 
 
 
 
 
 
20b4e3c
cab4035
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
20b4e3c
 
 
 
cab4035
20b4e3c
cab4035
 
 
8020b6e
20b4e3c
 
 
 
 
 
cab4035
20b4e3c
cab4035
8020b6e
20b4e3c
cab4035
 
20b4e3c
cab4035
 
8020b6e
cab4035
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
import os
import json
import torch
import streamlit as st
from datasets import Dataset
from transformers import AutoTokenizer, AutoModelForCausalLM, Trainer, TrainingArguments
from peft import LoraConfig, get_peft_model
from huggingface_hub import Repository

# -------- CONFIG ----------
# Hub repository id: the tokenizer and model are loaded from it, and it also
# names the Hub repo used by the push step.
MODEL_ID = "Neon-AI/Niche"
# Local directory used as the Trainer output dir and as the folder the trained
# model and tokenizer are saved into before a push.
CHECKPOINT_DIR = "./checkpoints"
# Hugging Face access token, read from Streamlit secrets when the script starts.
HF_TOKEN = st.secrets["HF_TOKEN"]  # Put your HF token in Streamlit secrets

# Page heading shown at the top of the app.
st.title("🧠 Niche Trainer with Push to HF")

# ---------- Load model once ----------
# st.cache_resource memoizes the loader: the tokenizer and weights are built on
# the first call and the same objects are handed back on later script reruns.
@st.cache_resource
def load_model():
    """Load the tokenizer and causal-LM weights for ``MODEL_ID``.

    Weights are requested in float16 when CUDA is available and in float32
    otherwise. If the tokenizer defines no pad token, the EOS token is reused
    as padding and the embedding matrix is resized to ``len(tokenizer)``.

    Returns:
        A ``(tokenizer, model)`` tuple.
    """
    tok = AutoTokenizer.from_pretrained(MODEL_ID)

    weight_dtype = torch.float32
    if torch.cuda.is_available():
        weight_dtype = torch.float16
    net = AutoModelForCausalLM.from_pretrained(MODEL_ID, torch_dtype=weight_dtype)

    needs_pad_token = tok.pad_token is None
    if needs_pad_token:
        # Batches are padded to a fixed length during training, so a pad token
        # must exist.
        tok.pad_token = tok.eos_token
        net.resize_token_embeddings(len(tok))

    return tok, net

# Fetch the cached tokenizer and model; the sections below all use them.
tokenizer, model = load_model()

# ---------- LoRA / Full model selection ----------
# "Full model" trains the loaded model directly; "LoRA" wraps it in a PEFT
# adapter before training.
finetune_type = st.radio("Select fine-tune type:", ["Full model", "LoRA"])

# ---------- JSON input ----------
st.subheader("Paste your JSON training examples")
# The label shows the example with double quotes: the text is parsed with
# json.loads, which rejects single-quoted strings.
json_input = st.text_area(
    'JSON format: [{"prompt": "Hello", "response": "Hi there!"}, ...]',
    height=300,
    placeholder='[{"prompt": "...", "response": "..."}]',
)

# ---------- Max token length ----------
# Every example is truncated/padded to exactly this many tokens.
max_len = st.slider("Max token length", min_value=64, max_value=512, value=256)

# ---------- Train ----------
# True only during the script run in which a training job has just finished;
# Streamlit re-executes the script on every interaction, which resets it.
train_started = False
if st.button("Train"):
    try:
        # An empty text area is treated like an empty example list instead of
        # surfacing a JSON parse error.
        examples = json.loads(json_input) if json_input.strip() else []
        if not examples:
            st.warning("No examples provided!")
        else:
            # Validate the payload shape up front so the user gets a readable
            # message rather than a KeyError/TypeError from the formatting step.
            well_formed = isinstance(examples, list) and all(
                isinstance(e, dict) and "prompt" in e and "response" in e
                for e in examples
            )
            if not well_formed:
                raise ValueError(
                    'expected a JSON list of objects with "prompt" and "response" keys'
                )

            # One training string per example, in a fixed chat-style template.
            texts = [
                f"### User:\n{e['prompt']}\n\n### Assistant:\n{e['response']}"
                for e in examples
            ]
            ds = Dataset.from_dict({"text": texts})

            def tokenize(batch):
                """Tokenize a batch of texts and build causal-LM labels.

                Labels mirror ``input_ids`` except at padding positions, which
                are set to -100 so the loss ignores them. The attention mask is
                used (not the token id) because the pad token may be the EOS
                token.
                """
                out = tokenizer(
                    batch["text"],
                    truncation=True,
                    padding="max_length",
                    max_length=max_len,
                    return_attention_mask=True,
                )
                out["labels"] = [
                    [tok if keep else -100 for tok, keep in zip(ids, mask)]
                    for ids, mask in zip(out["input_ids"], out["attention_mask"])
                ]
                return out

            ds = ds.map(tokenize, batched=True)
            ds.set_format("torch")

            # ---------- Apply LoRA if selected ----------
            if finetune_type == "LoRA":
                peft_config = LoraConfig(
                    task_type="CAUSAL_LM",
                    r=16,
                    lora_alpha=32,
                    lora_dropout=0.1,
                    # NOTE(review): "c_attn" is the GPT-2 style attention
                    # projection name; confirm it matches MODEL_ID's modules.
                    target_modules=["c_attn"],
                )
                train_model = get_peft_model(model, peft_config)
            else:
                train_model = model

            args = TrainingArguments(
                output_dir=CHECKPOINT_DIR,
                per_device_train_batch_size=1,
                gradient_accumulation_steps=2,
                num_train_epochs=1,
                learning_rate=2e-5,
                logging_steps=1,
                save_strategy="no",  # nothing is checkpointed during training
                report_to="none",
            )

            trainer = Trainer(
                model=train_model,
                args=args,
                train_dataset=ds,
            )

            st.info("Training started...")
            trainer.train()
            st.success("✅ Training done!")
            train_started = True

            # Use trained model for chat
            model = train_model

    except json.JSONDecodeError as e:
        st.error(f"Invalid JSON: {e}")
    except Exception as e:
        # UI boundary: report any failure in the page instead of crashing.
        st.error(f"Error during training: {e}")

# ---------- Push to HF ----------
# Clicking any button makes Streamlit rerun the script, and on that rerun
# `train_started` is False again. Keep the trained model in session state so
# the push button is still rendered (and its click handled) after the run that
# did the training.
if train_started:
    st.session_state["trained_model"] = model

trained_model = st.session_state.get("trained_model")
if trained_model is not None and st.button("Push to Hugging Face"):
    try:
        # Local import: HTTP-based upload that needs no local git clone, so it
        # works whether or not CHECKPOINT_DIR is already a git repository.
        from huggingface_hub import HfApi

        # Save trained model + tokenizer
        trained_model.save_pretrained(CHECKPOINT_DIR)
        tokenizer.save_pretrained(CHECKPOINT_DIR)

        # Push
        HfApi().upload_folder(
            folder_path=CHECKPOINT_DIR,
            repo_id=MODEL_ID,
            token=HF_TOKEN,
            commit_message="Update Niche model with new training",
        )
        st.success("✅ Model pushed to HF successfully!")

    except Exception as e:
        # UI boundary: show the failure in the page instead of crashing.
        st.error(f"Push failed: {e}")

# ---------- Chat ----------
st.subheader("Test the model")
user_prompt = st.text_input("You:", "")
# NOTE(review): the raw prompt is sent as-is, while training wraps examples in
# the "### User: / ### Assistant:" template — confirm this is intended.
if st.button("Send"):
    if user_prompt.strip():
        # Training leaves the model in train mode; switch to eval so dropout
        # is disabled while generating.
        model.eval()
        inputs = tokenizer(user_prompt, return_tensors="pt").to(model.device)
        outputs = model.generate(
            **inputs,
            max_new_tokens=100,
            do_sample=True,
            temperature=0.7,
            # Explicit pad id avoids the fallback-to-EOS warning from generate.
            pad_token_id=tokenizer.pad_token_id,
        )
        # The decoded sequence contains the prompt followed by the completion.
        response = tokenizer.decode(outputs[0], skip_special_tokens=True)
        st.text_area("Niche:", value=response, height=200)