sourize committed on
Commit
c283634
·
verified ·
1 Parent(s): 70fd1ee

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +89 -81
app.py CHANGED
@@ -1,105 +1,113 @@
1
  import os
2
  import streamlit as st
3
- from transformers import pipeline, AutoTokenizer, AutoModelForCausalLM
 
 
 
 
 
4
  from peft import PeftModel
5
  from supabase import create_client
6
  from sentence_transformers import SentenceTransformer
 
7
 
8
- # — Supabase creds from Secrets —
9
  SUPA_URL = os.getenv("SUPABASE_URL")
10
  SUPA_KEY = os.getenv("SUPABASE_SERVICE_ROLE_KEY")
11
  supabase = create_client(SUPA_URL, SUPA_KEY)
12
 
13
- # — Embedding model & retrieval function —
14
- embedder = SentenceTransformer("paraphrase-MiniLM-L3-v2")
 
 
 
 
 
 
15
  def fetch_mems(query, k=5):
16
  vec = embedder.encode(query).tolist()
17
- data = supabase.rpc(
18
- "match_memories",
19
- {"query_embedding": vec, "match_count": k}
20
- ).execute().data
21
- return data
22
 
23
  def add_mem(speaker, text):
24
  vec = embedder.encode(text).tolist()
25
  supabase.table("memories").insert({
26
- "speaker": speaker,
27
- "text": text,
28
- "embedding": vec
29
  }).execute()
30
 
31
- # — Load tokenizer & adapter from HF hub —
32
- REPO = "sourize/phi2-memory-lora"
33
-
34
- # 1) Tokenizer (with your extra PAD token)
35
- tokenizer = AutoTokenizer.from_pretrained(
36
- REPO, trust_remote_code=True, padding_side="left"
37
- )
38
- if tokenizer.pad_token_id is None:
39
- tokenizer.add_special_tokens({"pad_token": "[PAD]"})
40
-
41
- # 2) Base Phi-2 → resize embeddings to match tokenizer
42
- base = AutoModelForCausalLM.from_pretrained(
43
- "microsoft/phi-2", trust_remote_code=True, torch_dtype="auto"
44
- )
45
- base.resize_token_embeddings(len(tokenizer))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
46
 
47
- # 3) Overlay your LoRA adapter
48
- model = PeftModel.from_pretrained(
49
- base,
50
- REPO,
51
- torch_dtype="auto",
52
- device_map="auto" # let accelerate pick CPU/GPU
53
- )
54
- model.eval()
55
-
56
- # 4) Build the generation pipeline
57
- pipe = pipeline(
58
- "text-generation",
59
- model=model,
60
- tokenizer=tokenizer,
61
- device=0, # or device_map="auto"
62
- do_sample=True,
63
- top_p=0.9,
64
- temperature=0.8,
65
- )
66
-
67
- # — Streamlit UI —
68
- st.title("🧠 Memory-Aware Phi-2 Bot")
69
  if "history" not in st.session_state:
70
- st.session_state.history = []
71
-
72
- def chat(u: str) -> str:
73
- # store user turn
74
- add_mem("user", u)
75
-
76
- # fetch & format memories
77
- mems = fetch_mems(u, 3)
78
- block = "\n".join(f"{m['speaker']}: {m['text']}" for m in mems)
79
-
80
- # build prompt
81
- prompt = f"""Memory:
82
- {block}
83
-
84
- User: {u}
85
- Assistant:"""
86
-
87
- # generate reply
88
- out = pipe(prompt, max_length=200)[0]["generated_text"]
89
- reply = out.split("Assistant:")[-1].strip()
90
-
91
- # store assistant turn
 
 
 
 
92
  add_mem("assistant", reply)
93
  return reply
94
 
95
- user = st.text_input("You:")
96
- if user:
97
- resp = chat(user)
98
- st.session_state.history.append(("You", user))
99
- st.session_state.history.append(("Bot", resp))
100
-
101
- for speaker, text in st.session_state.history:
102
- if speaker == "You":
103
- st.markdown(f"**You:** {text}")
104
  else:
105
- st.markdown(f"**Assistant:** {text}")
 
 
 
 
 
 
 
 
 
1
  import os
2
  import streamlit as st
3
+ from transformers import (
4
+ pipeline,
5
+ AutoTokenizer,
6
+ AutoModelForCausalLM,
7
+ TextIteratorStreamer
8
+ )
9
  from peft import PeftModel
10
  from supabase import create_client
11
  from sentence_transformers import SentenceTransformer
12
+ import threading
13
 
14
+ # ── 1) Supabase setup ───────────────────────────────────────────────────────
15
# Supabase connection settings are read from the process environment; the
# client built from them is shared by fetch_mems() and add_mem().
SUPA_URL = os.environ.get("SUPABASE_URL")
SUPA_KEY = os.environ.get("SUPABASE_SERVICE_ROLE_KEY")
supabase = create_client(SUPA_URL, SUPA_KEY)
18
 
19
+ # ── 2) Embedder & memory RPC ────────────────────────────────────────────────
20
@st.cache_resource(show_spinner=False)
def get_embedder():
    """Return the sentence-embedding model used for memory vectors.

    Decorated with ``st.cache_resource`` so the model object is built once
    and reused rather than reconstructed on each script run.
    """
    model_name = "paraphrase-MiniLM-L3-v2"
    return SentenceTransformer(model_name)


# Module-level handle used by fetch_mems() and add_mem().
embedder = get_embedder()
25
+
26
def fetch_mems(query, k=5):
    """Look up stored memories similar to ``query``.

    The query text is embedded with the shared ``embedder`` and the vector is
    sent to the ``match_memories`` Supabase RPC together with ``match_count``.

    This function is deliberately NOT wrapped in ``st.cache_data``: add_mem()
    inserts new rows between calls, so a cached answer for a repeated query
    would silently omit every memory stored after the first lookup.

    Args:
        query: Text to search for.
        k: Value passed to the RPC as ``match_count``.

    Returns:
        The rows returned by the RPC (callers read ``speaker`` and ``text``
        from each row), or an empty list when the response carries no data.
    """
    vec = embedder.encode(query).tolist()
    response = supabase.rpc(
        "match_memories",
        {"query_embedding": vec, "match_count": k},
    ).execute()
    # Normalise a missing payload to an empty list so callers can iterate.
    return response.data or []
 
 
 
 
30
 
31
def add_mem(speaker, text):
    """Embed ``text`` and insert it as a new row in the ``memories`` table.

    Args:
        speaker: Label stored in the row's ``speaker`` column.
        text: Message body; stored verbatim and also embedded.
    """
    row = {
        "speaker": speaker,
        "text": text,
        "embedding": embedder.encode(text).tolist(),
    }
    supabase.table("memories").insert(row).execute()
36
 
37
+ # ── 3) Model + tokenizer (cached) ───────────────────────────────────────────
38
@st.cache_resource(show_spinner=False)
def load_model():
    """Load the tokenizer, base model and LoRA adapter, and build the pipeline.

    Decorated with ``st.cache_resource`` so the expensive load is performed
    once and the same objects are reused afterwards.

    Returns:
        A ``(tokenizer, pipeline)`` tuple. The pipeline is a greedy
        (``do_sample=False``) text-generation pipeline limited to 64 new
        tokens that returns only the generated continuation.
    """
    REPO = "sourize/phi2-memory-lora"

    # Tokenizer comes from the adapter repo; add a PAD token if it has none.
    tok = AutoTokenizer.from_pretrained(REPO, trust_remote_code=True, padding_side="left")
    if tok.pad_token_id is None:
        tok.add_special_tokens({"pad_token": "[PAD]"})

    # Base model, with its embedding matrix resized to the tokenizer's vocab
    # so the ids of any added tokens are valid.
    base = AutoModelForCausalLM.from_pretrained("microsoft/phi-2", trust_remote_code=True)
    base.resize_token_embeddings(len(tok))

    # Overlay the LoRA adapter and switch to inference mode.
    model = PeftModel.from_pretrained(base, REPO, device_map="auto", torch_dtype="auto")
    model.eval()

    # No `streamer` is configured here. A streamer must be an *instance*
    # (passing the TextIteratorStreamer class makes generation fail), and an
    # instance is consumed by a single generation, so it cannot live inside a
    # cached pipeline. Callers that want streaming pass a fresh streamer as a
    # per-call keyword argument instead.
    text_generator = pipeline(
        "text-generation",
        model=model,
        tokenizer=tok,
        device_map="auto",
        max_new_tokens=64,
        do_sample=False,
        use_cache=True,
        return_full_text=False,
    )
    return tok, text_generator


tokenizer, generator = load_model()
65
+
66
+ # ── 4) Streamlit UI setup ───────────────────────────────────────────────────
67
# Page chrome: wide layout and the app title.
st.set_page_config(layout="wide")
st.title("🧠 Memory-Aware Phi-2 Chat")

# Conversation transcript kept in session state: a list of (role, message)
# tuples, created on first use.
if "history" not in st.session_state:
    st.session_state["history"] = []
72
+
73
+ # ── 5) Chat function ────────────────────────────────────────────────────────
74
def chat(user_input: str):
    """Answer ``user_input``, streaming the reply into an assistant bubble.

    Steps:
      1. Store the user turn as a memory.
      2. Retrieve the top-3 matching memories and build the prompt.
      3. Run the generation pipeline in a background thread with a fresh
         ``TextIteratorStreamer`` and render text chunks as they arrive.
      4. Write the final reply into the trailing ``("Bot", "")`` placeholder
         of ``st.session_state.history`` (when the caller appended one) and
         store the reply as a memory.

    Args:
        user_input: The message typed by the user.

    Returns:
        The generated reply with surrounding whitespace stripped.
    """
    add_mem("user", user_input)

    # Retrieve top-3 memories; tolerate a lookup that yields nothing.
    mems = fetch_mems(user_input, k=3) or []
    mem_block = "\n".join(f"{m['speaker']}: {m['text']}" for m in mems)
    prompt = f"Memory:\n{mem_block}\n\nUser: {user_input}\nAssistant:"

    # A streamer is single-use, so build one per call. Passing it as a
    # call-time keyword overrides any streamer configured on the pipeline.
    streamer = TextIteratorStreamer(
        generator.tokenizer, skip_prompt=True, skip_special_tokens=True
    )

    def _generate():
        """Run the pipeline; on failure, unblock the consumer loop below."""
        try:
            generator(prompt, streamer=streamer)
        except Exception:
            # Without this the `for piece in streamer` loop would wait
            # forever for an end-of-stream signal that never comes.
            streamer.end()
            raise

    worker = threading.Thread(target=_generate, daemon=True)
    worker.start()

    # Update one placeholder element in place. Triggering a script rerun per
    # token would abort this function before the reply is complete.
    bubble = st.chat_message("assistant").empty()
    pieces = []
    for piece in streamer:
        pieces.append(piece)
        bubble.write("".join(pieces))
    worker.join()

    reply = "".join(pieces).strip()

    # Fill the empty placeholder the caller appended, if there is one.
    history = st.session_state.history
    if history and history[-1] == ("Bot", ""):
        history[-1] = ("Bot", reply)

    add_mem("assistant", reply)
    return reply
99
 
100
+ # ── 6) Render chat bubbles & input ──────────────────────────────────────────
101
# Replay the stored transcript as chat bubbles.
for role, msg in st.session_state.history:
    bubble_kind = "user" if role == "You" else "assistant"
    st.chat_message(bubble_kind).write(msg)

user_input = st.chat_input("Type your message...")
if user_input:
    # Record the new turn plus an empty placeholder for chat() to fill in.
    st.session_state.history.append(("You", user_input))
    st.session_state.history.append(("Bot", ""))
    # The replay loop above ran before these entries existed, so draw the
    # submitted message now; otherwise it stays invisible until the next run.
    st.chat_message("user").write(user_input)
    # Run chat (which will update the last bubble).
    chat(user_input)