sourize committed
Commit a79070b · verified
1 Parent(s): b67224f

Update app.py

Files changed (1)
  1. app.py +60 -28
app.py CHANGED
@@ -1,6 +1,8 @@
 import os
 import streamlit as st
-from transformers import pipeline, AutoTokenizer, AutoModelForCausalLM
+from transformers import (
+    pipeline, AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
+)
 from peft import PeftModel
 from supabase import create_client
 from sentence_transformers import SentenceTransformer
@@ -18,14 +20,15 @@ def get_embedder():
 embedder = get_embedder()
 
 @st.cache_data(show_spinner=False)
-def fetch_mems(query, k=5):
-    vec = embedder.encode(query).tolist()
-    return supabase.rpc("match_memories",
-        {"query_embedding": vec, "match_count": k}
-    ).execute().data
+def fetch_mems(query, k=3):
+    vec = embedder.encode(query).astype('float32').tolist()
+    return supabase.rpc(
+        "match_memories",
+        {"query_embedding": vec, "match_count": k}
+    ).execute().data
 
 def add_mem(speaker, text):
-    vec = embedder.encode(text).tolist()
+    vec = embedder.encode(text).astype('float32').tolist()
     supabase.table("memories").insert({
         "speaker": speaker, "text": text, "embedding": vec
     }).execute()
@@ -35,34 +38,53 @@ def add_mem(speaker, text):
 def load_generator():
     REPO = "sourize/phi2-memory-lora"
     # 1) Tokenizer
-    tokenizer = AutoTokenizer.from_pretrained(REPO, trust_remote_code=True, padding_side="left")
+    tokenizer = AutoTokenizer.from_pretrained(
+        REPO, trust_remote_code=True, padding_side="left"
+    )
     if tokenizer.pad_token_id is None:
         tokenizer.add_special_tokens({"pad_token": "[PAD]"})
-    # 2) Base model & resize
+    # 2) Quantization config for 4-bit
+    bnb_config = BitsAndBytesConfig(
+        load_in_4bit=True,
+        bnb_4bit_quant_type="nf4",
+        bnb_4bit_compute_dtype="float16",
+        low_cpu_mem_usage=True,
+    )
+    # 3) Load base model in 4-bit + resize embeddings
     base = AutoModelForCausalLM.from_pretrained(
-        "microsoft/phi-2", trust_remote_code=True, torch_dtype="auto"
+        "microsoft/phi-2",
+        trust_remote_code=True,
+        quantization_config=bnb_config,
+        device_map="auto"
     )
     base.resize_token_embeddings(len(tokenizer))
-    # 3) Overlay LoRA adapter
-    model = PeftModel.from_pretrained(
-        base, REPO, device_map="auto", torch_dtype="auto"
-    )
+    # 4) Overlay LoRA adapter
+    model = PeftModel.from_pretrained(base, REPO, device_map="auto", torch_dtype="auto")
     model.eval()
-    # 4) Pipeline (greedy, small output for speed)
+    # 5) Pipeline with greedy sampling + constraints
     gen = pipeline(
         "text-generation",
         model=model,
         tokenizer=tokenizer,
         device_map="auto",
-        max_new_tokens=64,
-        do_sample=False,  # greedy decoding
+        max_new_tokens=32,
+        do_sample=True,
+        temperature=0.2,
+        top_p=0.8,
         use_cache=True,
-        return_full_text=False,
+        return_full_text=False
    )
     return tokenizer, gen
 
 tokenizer, generator = load_generator()
 
+# ── System prompt to reduce hallucinations ──────────────────────────────────
+SYSTEM = (
+    "You are a helpful assistant.\\n"
+    "Answer **only** using the information in the memory below.\\n"
+    "If the answer is not in memory, reply: \"I don't know.\"\\n"
+)
+
 # ── Streamlit UI ──────────────────────────────────────────────────────────
 st.set_page_config(layout="wide")
 st.title("🧠 Memory-Aware Phi-2 Chat")
@@ -70,7 +92,7 @@ st.title("🧠 Memory-Aware Phi-2 Chat")
 if "history" not in st.session_state:
     st.session_state.history = []  # list of (role, message)
 
-# Render all previous messages as chat bubbles
+# Render existing chat history
 for role, msg in st.session_state.history:
     if role == "You":
         st.chat_message("user").write(msg)
@@ -81,21 +103,31 @@ for role, msg in st.session_state.history:
 user_input = st.chat_input("Type your message...")
 
 if user_input:
-    # 1) show user bubble
+    # Append user message
     st.session_state.history.append(("You", user_input))
-
-    # 2) store user turn
     add_mem("user", user_input)
 
-    # 3) retrieve memories and build prompt
+    # Retrieve relevant memories
     mems = fetch_mems(user_input, k=3)
     mem_block = "\n".join(f"{m['speaker']}: {m['text']}" for m in mems)
-    prompt = f"Memory:\n{mem_block}\n\nUser: {user_input}\nAssistant:"
 
-    # 4) generate reply with spinner
+    # Build prompt
+    prompt = f"""{SYSTEM}
+
+Memory:
+{mem_block}
+
+User: {user_input}
+Assistant:"""
+
+    # Generate reply synchronously with spinner
     with st.spinner("Thinking..."):
-        out = generator(prompt)[0]["generated_text"].strip()
+        try:
+            out = generator(prompt)[0]["generated_text"].strip()
+        except Exception as e:
+            out = "Sorry, I encountered an error."
+            st.error(f"Generation error: {e}")
 
-    # 5) show bot bubble and record
+    # Append assistant reply
     st.session_state.history.append(("Bot", out))
-    add_mem("assistant", out)
+    add_mem("assistant", out)
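
Note that the retrieval path assumes a match_memories RPC already exists on the Supabase side (typically a Postgres function over a pgvector column); the commit only calls it. For local testing without that function, a rough client-side equivalent might look like the sketch below — the column names match the insert in add_mem, but fetch_mems_local and the brute-force cosine ranking are illustrative assumptions, not part of this commit.

import json
import numpy as np

def fetch_mems_local(query, k=3):
    # Hypothetical fallback: rank all stored memories client-side by
    # cosine similarity. Assumes the same `embedder` and `supabase`
    # globals app.py creates; the real app calls the RPC instead.
    q = embedder.encode(query).astype("float32")
    rows = supabase.table("memories").select("speaker, text, embedding").execute().data
    def cos(row):
        emb = row["embedding"]
        if isinstance(emb, str):  # pgvector columns may come back as strings
            emb = json.loads(emb)
        v = np.asarray(emb, dtype="float32")
        return float(np.dot(q, v) / (np.linalg.norm(q) * np.linalg.norm(v) + 1e-8))
    return sorted(rows, key=cos, reverse=True)[:k]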
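On the loading side, the commit swaps full-precision loading for 4-bit quantization before overlaying the LoRA adapter. Two caveats: in the transformers releases I'm aware of, low_cpu_mem_usage is an argument of from_pretrained rather than a BitsAndBytesConfig field (BitsAndBytesConfig ignores unknown kwargs with a warning), and with do_sample=True, temperature=0.2, top_p=0.8 the pipeline performs low-temperature sampling, not the greedy decoding the step-5 comment suggests. A minimal standalone smoke test of the load path, under those assumptions:

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
from peft import PeftModel

REPO = "sourize/phi2-memory-lora"  # same adapter repo as app.py

tokenizer = AutoTokenizer.from_pretrained(REPO, trust_remote_code=True, padding_side="left")
if tokenizer.pad_token_id is None:
    tokenizer.add_special_tokens({"pad_token": "[PAD]"})

bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.float16,
)
base = AutoModelForCausalLM.from_pretrained(
    "microsoft/phi-2",
    trust_remote_code=True,
    quantization_config=bnb_config,
    device_map="auto",
    low_cpu_mem_usage=True,  # from_pretrained kwarg, not a BitsAndBytesConfig field
)
base.resize_token_embeddings(len(tokenizer))
model = PeftModel.from_pretrained(base, REPO, device_map="auto")
model.eval()

# 4-bit phi-2 should report a footprint well under its fp16 size.
print(f"{model.get_memory_footprint() / 1e9:.2f} GB")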