Neon-AI committed on
Commit
20b4e3c
·
verified ·
1 Parent(s): cab4035

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +25 -15
app.py CHANGED
@@ -1,20 +1,21 @@
1
- import streamlit as st
2
  import json
3
  import torch
4
- from transformers import AutoTokenizer, AutoModelForCausalLM, Trainer, TrainingArguments
5
  from datasets import Dataset
 
6
  from peft import LoraConfig, get_peft_model
7
- from huggingface_hub import HfApi, HfFolder, Repository
8
 
9
  # -------- CONFIG ----------
10
  MODEL_ID = "Neon-AI/Niche"
11
  CHECKPOINT_DIR = "./checkpoints"
12
- HF_TOKEN = st.secrets["HF_TOKEN"]
13
 
14
  st.title("🧠 Niche Trainer with Push to HF")
15
 
16
  # ---------- Load model once ----------
17
- @st.cache_resource
18
  def load_model():
19
  tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
20
  model = AutoModelForCausalLM.from_pretrained(
@@ -39,6 +40,9 @@ json_input = st.text_area(
39
  placeholder='[{"prompt": "...", "response": "..."}]'
40
  )
41
 
 
 
 
42
  # ---------- Train ----------
43
  train_started = False
44
  if st.button("Train"):
@@ -54,7 +58,7 @@ if st.button("Train"):
54
  ds = Dataset.from_dict({"text": texts})
55
 
56
  def tokenize(batch):
57
- out = tokenizer(batch["text"], truncation=True, padding="max_length", max_length=256)
58
  out["labels"] = out["input_ids"].copy()
59
  return out
60
 
@@ -70,8 +74,7 @@ if st.button("Train"):
70
  lora_dropout=0.1,
71
  target_modules=["c_attn"]
72
  )
73
- model_peft = get_peft_model(model, peft_config)
74
- train_model = model_peft
75
  else:
76
  train_model = model
77
 
@@ -96,23 +99,30 @@ if st.button("Train"):
96
  trainer.train()
97
  st.success("✅ Training done!")
98
  train_started = True
 
 
 
 
99
  except Exception as e:
100
- st.error(f"Error: {e}")
101
 
102
  # ---------- Push to HF ----------
103
  if train_started and st.button("Push to Hugging Face"):
104
  try:
105
- repo = Repository(
106
- local_dir=CHECKPOINT_DIR,
107
- clone_from=MODEL_ID,
108
- use_auth_token=HF_TOKEN
109
- )
 
110
  # Save trained model + tokenizer
111
- train_model.save_pretrained(CHECKPOINT_DIR)
112
  tokenizer.save_pretrained(CHECKPOINT_DIR)
113
 
 
114
  repo.push_to_hub(commit_message="Update Niche model with new training")
115
  st.success("✅ Model pushed to HF successfully!")
 
116
  except Exception as e:
117
  st.error(f"Push failed: {e}")
118
 
 
1
+ import os
2
  import json
3
  import torch
4
+ import streamlit as st
5
  from datasets import Dataset
6
+ from transformers import AutoTokenizer, AutoModelForCausalLM, Trainer, TrainingArguments
7
  from peft import LoraConfig, get_peft_model
8
+ from huggingface_hub import Repository
9
 
10
  # -------- CONFIG ----------
11
  MODEL_ID = "Neon-AI/Niche"
12
  CHECKPOINT_DIR = "./checkpoints"
13
+ HF_TOKEN = st.secrets["HF_TOKEN"] # Put your HF token in Streamlit secrets
14
 
15
  st.title("🧠 Niche Trainer with Push to HF")
16
 
17
  # ---------- Load model once ----------
18
+ @st.cache_resource(show_spinner=True)
19
  def load_model():
20
  tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
21
  model = AutoModelForCausalLM.from_pretrained(
 
40
  placeholder='[{"prompt": "...", "response": "..."}]'
41
  )
42
 
43
+ # ---------- Max token length ----------
44
+ max_len = st.slider("Max token length", min_value=64, max_value=512, value=256)
45
+
46
  # ---------- Train ----------
47
  train_started = False
48
  if st.button("Train"):
 
58
  ds = Dataset.from_dict({"text": texts})
59
 
60
  def tokenize(batch):
61
+ out = tokenizer(batch["text"], truncation=True, padding="max_length", max_length=max_len)
62
  out["labels"] = out["input_ids"].copy()
63
  return out
64
 
 
74
  lora_dropout=0.1,
75
  target_modules=["c_attn"]
76
  )
77
+ train_model = get_peft_model(model, peft_config)
 
78
  else:
79
  train_model = model
80
 
 
99
  trainer.train()
100
  st.success("✅ Training done!")
101
  train_started = True
102
+
103
+ # Use trained model for chat
104
+ model = train_model
105
+
106
  except Exception as e:
107
+ st.error(f"Error during training: {e}")
108
 
109
  # ---------- Push to HF ----------
110
  if train_started and st.button("Push to Hugging Face"):
111
  try:
112
+ # Prepare repo
113
+ if os.path.exists(CHECKPOINT_DIR):
114
+ repo = Repository(local_dir=CHECKPOINT_DIR, use_auth_token=HF_TOKEN)
115
+ else:
116
+ repo = Repository(local_dir=CHECKPOINT_DIR, clone_from=MODEL_ID, use_auth_token=HF_TOKEN)
117
+
118
  # Save trained model + tokenizer
119
+ model.save_pretrained(CHECKPOINT_DIR)
120
  tokenizer.save_pretrained(CHECKPOINT_DIR)
121
 
122
+ # Push
123
  repo.push_to_hub(commit_message="Update Niche model with new training")
124
  st.success("✅ Model pushed to HF successfully!")
125
+
126
  except Exception as e:
127
  st.error(f"Push failed: {e}")
128