Neon-AI commited on
Commit
ce032b0
verified
1 Parent(s): 9f25d4c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +45 -15
app.py CHANGED
@@ -1,28 +1,58 @@
1
  import streamlit as st
2
- from transformers import AutoModelForCausalLM, AutoTokenizer
3
  import torch
 
 
 
4
 
5
- st.title("Niche AI Chat")
6
- st.write("Chat with your trained Niche model from Hugging Face!")
 
 
7
 
8
- # Load model once
9
  @st.cache_resource
10
  def load_model():
11
- model_name = "Neon-AI/Niche" # Your HF repo
12
- tokenizer = AutoTokenizer.from_pretrained(model_name)
13
- model = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto")
 
 
 
 
 
 
 
 
14
  return tokenizer, model
15
 
16
  tokenizer, model = load_model()
17
 
18
- # User input
19
- prompt = st.text_input("You:", "")
 
 
 
20
 
21
  if st.button("Send"):
22
- if prompt.strip() == "":
23
- st.warning("Please type something!")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
24
  else:
25
- inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
26
- outputs = model.generate(**inputs, max_new_tokens=100, do_sample=True, temperature=0.7)
27
- response = tokenizer.decode(outputs[0], skip_special_tokens=True)
28
- st.text_area("Niche:", value=response, height=200)
 
1
  import streamlit as st
 
2
  import torch
3
+ from transformers import AutoModelForCausalLM, AutoTokenizer
4
+
5
+ st.set_page_config(page_title="Niche AI", layout="centered")
6
 
7
+ st.title("馃 Niche AI (CPU Test)")
8
+ st.caption("HF Free Space 2B params slow but real")
9
+
10
+ MODEL_ID = "Neon-AI/Niche"
11
 
 
12
  @st.cache_resource
13
  def load_model():
14
+ tokenizer = AutoTokenizer.from_pretrained(
15
+ MODEL_ID,
16
+ trust_remote_code=True
17
+ )
18
+ model = AutoModelForCausalLM.from_pretrained(
19
+ MODEL_ID,
20
+ torch_dtype=torch.float32,
21
+ device_map=None # 馃憟 IMPORTANT
22
+ )
23
+ model.to("cpu")
24
+ model.eval()
25
  return tokenizer, model
26
 
27
  tokenizer, model = load_model()
28
 
29
+ # Session chat history
30
+ if "history" not in st.session_state:
31
+ st.session_state.history = []
32
+
33
+ prompt = st.text_input("You", placeholder="Say something...")
34
 
35
  if st.button("Send"):
36
+ if prompt.strip():
37
+ st.session_state.history.append(("You", prompt))
38
+
39
+ inputs = tokenizer(prompt, return_tensors="pt")
40
+
41
+ with torch.no_grad():
42
+ output = model.generate(
43
+ **inputs,
44
+ max_new_tokens=64, # keep it sane on CPU
45
+ do_sample=True,
46
+ temperature=0.8,
47
+ top_p=0.95
48
+ )
49
+
50
+ reply = tokenizer.decode(output[0], skip_special_tokens=True)
51
+ st.session_state.history.append(("Niche", reply))
52
+
53
+ # Display chat
54
+ for speaker, text in st.session_state.history:
55
+ if speaker == "You":
56
+ st.markdown(f"**You:** {text}")
57
  else:
58
+ st.markdown(f"**Niche:** {text}")