sourize committed
Commit b67224f · verified · 1 Parent(s): c283634

Update app.py

Files changed (1):
  1. app.py +48 -60

app.py CHANGED
@@ -1,22 +1,16 @@
 import os
 import streamlit as st
-from transformers import (
-    pipeline,
-    AutoTokenizer,
-    AutoModelForCausalLM,
-    TextIteratorStreamer
-)
+from transformers import pipeline, AutoTokenizer, AutoModelForCausalLM
 from peft import PeftModel
 from supabase import create_client
 from sentence_transformers import SentenceTransformer
-import threading
 
-# ── 1) Supabase setup ───────────────────────────────────────────────────────
+# ── Supabase setup ─────────────────────────────────────────────────────────
 SUPA_URL = os.getenv("SUPABASE_URL")
 SUPA_KEY = os.getenv("SUPABASE_SERVICE_ROLE_KEY")
 supabase = create_client(SUPA_URL, SUPA_KEY)
 
-# ── 2) Embedder & memory RPC ────────────────────────────────────────────────
+# ── Embedder & memory RPC ──────────────────────────────────────────────────
 @st.cache_resource(show_spinner=False)
 def get_embedder():
     return SentenceTransformer("paraphrase-MiniLM-L3-v2")
@@ -26,7 +20,9 @@ embedder = get_embedder()
 @st.cache_data(show_spinner=False)
 def fetch_mems(query, k=5):
     vec = embedder.encode(query).tolist()
-    return supabase.rpc("match_memories", {"query_embedding": vec, "match_count": k}).execute().data
+    return supabase.rpc("match_memories",
+                        {"query_embedding": vec, "match_count": k}
+                        ).execute().data
 
 def add_mem(speaker, text):
     vec = embedder.encode(text).tolist()
@@ -34,80 +30,72 @@ def add_mem(speaker, text):
         "speaker": speaker, "text": text, "embedding": vec
     }).execute()
 
-# ── 3) Model + tokenizer (cached) ───────────────────────────────────────────
+# ── Model + tokenizer ──────────────────────────────────────────────────────
 @st.cache_resource(show_spinner=False)
-def load_model():
+def load_generator():
     REPO = "sourize/phi2-memory-lora"
-    # tokenizer
-    tok = AutoTokenizer.from_pretrained(REPO, trust_remote_code=True, padding_side="left")
-    if tok.pad_token_id is None:
-        tok.add_special_tokens({"pad_token": "[PAD]"})
-    # base + resize
-    base = AutoModelForCausalLM.from_pretrained("microsoft/phi-2", trust_remote_code=True)
-    base.resize_token_embeddings(len(tok))
-    # adapter overlay
-    model = PeftModel.from_pretrained(base, REPO, device_map="auto", torch_dtype="auto")
+    # 1) Tokenizer
+    tokenizer = AutoTokenizer.from_pretrained(REPO, trust_remote_code=True, padding_side="left")
+    if tokenizer.pad_token_id is None:
+        tokenizer.add_special_tokens({"pad_token": "[PAD]"})
+    # 2) Base model & resize
+    base = AutoModelForCausalLM.from_pretrained(
+        "microsoft/phi-2", trust_remote_code=True, torch_dtype="auto"
+    )
+    base.resize_token_embeddings(len(tokenizer))
+    # 3) Overlay LoRA adapter
+    model = PeftModel.from_pretrained(
+        base, REPO, device_map="auto", torch_dtype="auto"
+    )
     model.eval()
-    # prepare a streaming pipeline
-    return tok, pipeline(
+    # 4) Pipeline (greedy, small output for speed)
+    gen = pipeline(
         "text-generation",
         model=model,
-        tokenizer=tok,
+        tokenizer=tokenizer,
         device_map="auto",
         max_new_tokens=64,
-        do_sample=False,
+        do_sample=False,  # greedy decoding
         use_cache=True,
         return_full_text=False,
-        streamer=TextIteratorStreamer  # enable streaming
     )
+    return tokenizer, gen
 
-tokenizer, generator = load_model()
+tokenizer, generator = load_generator()
 
-# ── 4) Streamlit UI setup ───────────────────────────────────────────────────
+# ── Streamlit UI ───────────────────────────────────────────────────────────
 st.set_page_config(layout="wide")
 st.title("🧠 Memory-Aware Phi-2 Chat")
 
 if "history" not in st.session_state:
     st.session_state.history = []  # list of (role, message)
 
-# ── 5) Chat function ────────────────────────────────────────────────────────
-def chat(user_input: str):
-    add_mem("user", user_input)
-    # retrieve top-3 memories
-    mems = fetch_mems(user_input, k=3)
-    mem_block = "\n".join(f"{m['speaker']}: {m['text']}" for m in mems)
-    prompt = f"Memory:\n{mem_block}\n\nUser: {user_input}\nAssistant:"
-
-    # stream generation
-    streamer = generator.tokenizer.streamer if hasattr(generator.tokenizer, "streamer") else None
-    if streamer:
-        # If using TextIteratorStreamer, kick off async thread
-        thread = threading.Thread(target=generator, kwargs={"prompt": prompt})
-        thread.start()
-        output = ""
-        for token in streamer:
-            output += token
-            # update the last message in session_state so UI refreshes
-            st.session_state.history[-1] = ("Bot", output)
-            st.experimental_rerun()
-        thread.join()
-    else:
-        output = generator(prompt)[0]["generated_text"]
-    reply = output.strip()
-    add_mem("assistant", reply)
-    return reply
-
-# ── 6) Render chat bubbles & input ──────────────────────────────────────────
+# Render all previous messages as chat bubbles
 for role, msg in st.session_state.history:
     if role == "You":
         st.chat_message("user").write(msg)
     else:
         st.chat_message("assistant").write(msg)
 
+# Input box at the bottom
 user_input = st.chat_input("Type your message...")
+
 if user_input:
-    # append placeholder so streamer can fill it
+    # 1) show user bubble
     st.session_state.history.append(("You", user_input))
-    st.session_state.history.append(("Bot", ""))
-    # run chat (which will update the last bubble)
-    chat(user_input)
+
+    # 2) store user turn
+    add_mem("user", user_input)
+
+    # 3) retrieve memories and build prompt
+    mems = fetch_mems(user_input, k=3)
+    mem_block = "\n".join(f"{m['speaker']}: {m['text']}" for m in mems)
+    prompt = f"Memory:\n{mem_block}\n\nUser: {user_input}\nAssistant:"
+
+    # 4) generate reply with spinner
+    with st.spinner("Thinking..."):
+        out = generator(prompt)[0]["generated_text"].strip()

+    # 5) show bot bubble and record
+    st.session_state.history.append(("Bot", out))
+    add_mem("assistant", out)
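Note on the prompt format: fetch_mems returns rows from the match_memories RPC, and the new inline handler flattens them into a "Memory:" preamble. A minimal sketch of what the generator actually sees, with illustrative rows standing in for real RPC output (the row shape beyond speaker/text is an assumption):

# Illustrative rows; real match_memories results would typically also carry
# id and similarity columns, depending on how the RPC is defined.
mems = [
    {"speaker": "user", "text": "My name is Sam."},
    {"speaker": "assistant", "text": "Nice to meet you, Sam!"},
]
mem_block = "\n".join(f"{m['speaker']}: {m['text']}" for m in mems)
prompt = f"Memory:\n{mem_block}\n\nUser: What's my name?\nAssistant:"
# prompt is now:
# Memory:
# user: My name is Sam.
# assistant: Nice to meet you, Sam!
#
# User: What's my name?
# Assistant:

Since the pipeline is built with return_full_text=False, only the completion after "Assistant:" comes back from the generator call.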
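A note on the streaming code this commit removes: the old version passed the TextIteratorStreamer class itself into pipeline(...) and then probed generator.tokenizer.streamer, an attribute that never exists, so the streaming branch was unreachable. For reference, a minimal sketch of the streaming pattern the transformers docs describe, written against the model/tokenizer pair that load_generator builds (the stream_reply wrapper name is hypothetical):

import threading
from transformers import TextIteratorStreamer

def stream_reply(model, tokenizer, prompt, max_new_tokens=64):
    # The streamer is instantiated per call and passed to generate(), not to pipeline().
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    thread = threading.Thread(
        target=model.generate,
        kwargs={**inputs, "max_new_tokens": max_new_tokens, "streamer": streamer},
    )
    thread.start()
    for chunk in streamer:  # yields decoded text pieces as generation proceeds
        yield chunk
    thread.join()

In Streamlit, those chunks could then be rendered incrementally with st.write_stream(stream_reply(...)) rather than the rerun-per-token loop the removed code attempted.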