sourize committed on
Commit
d934644
·
1 Parent(s): 7fab575
Files changed (1)
  1. app.py +52 -96
app.py CHANGED
@@ -8,108 +8,80 @@ from transformers import (
     BitsAndBytesConfig,
 )
 from peft import LoraConfig, get_peft_model
-from supabase import create_client
-from sentence_transformers import SentenceTransformer
 from safetensors.torch import load_file as safe_load
 
-# ── Supabase setup ─────────────────────────────────────────────────────────
-SUPA_URL = os.getenv("SUPABASE_URL")
-SUPA_KEY = os.getenv("SUPABASE_SERVICE_ROLE_KEY")
-supabase = create_client(SUPA_URL, SUPA_KEY)
-
-# ── Embedder & memory RPC ──────────────────────────────────────────────────
-@st.cache_resource(show_spinner=False)
-def get_embedder():
-    return SentenceTransformer("paraphrase-MiniLM-L3-v2")
-embedder = get_embedder()
-
-@st.cache_data(show_spinner=False)
-def fetch_mems(query, k=3):
-    vec = embedder.encode(query).astype("float32").tolist()
-    return supabase.rpc(
-        "match_memories",
-        {"query_embedding": vec, "match_count": k}
-    ).execute().data
-
-def add_mem(speaker, text):
-    vec = embedder.encode(text).astype("float32").tolist()
-    supabase.table("memories").insert({
-        "speaker": speaker,
-        "text": text,
-        "embedding": vec
-    }).execute()
-
-# ── Model + tokenizer (adapter locally, tokenizer remote) ─────────────────
+# ── Configuration ──────────────────────────────────────────────────────────
+MODEL_REPO = "models/phi2-deeptalk-lora"
+BASE_MODEL = "microsoft/phi-2"
+CONTEXT_TURNS = 7      # how many past messages to include
+MAX_NEW_TOKENS = 32    # shorter = faster
+TEMPERATURE = 0.0      # 0.0 = greedy
+TOP_P = 1.0            # disable nucleus sampling
+DEVICE_MAP = "auto"
+
+SYSTEM = (
+    "You are a helpful assistant for DeepTalks with a base model Phi-2 "
+    "fine-tuned by Sourish for domain-specific support.\n"
+    "Base replies **only** on the context below. "
+    "If you don't know, say “I don't know.”\n"
+)
+
+# ── Model Loader ───────────────────────────────────────────────────────────
 @st.cache_resource(show_spinner=False)
 def load_generator():
-    base_dir = os.path.dirname(__file__)
-    LOCAL_REPO = os.path.join(base_dir, "models", "phi2-deeptalk-lora")
-    OFFLOAD_DIR = os.path.join(base_dir, "offload")
-    os.makedirs(OFFLOAD_DIR, exist_ok=True)
-
-    # 1) Tokenizer from official Phi-2
+    # 1) Tokenizer (always from HuggingFace cache)
     tokenizer = AutoTokenizer.from_pretrained(
-        "microsoft/phi-2",
+        BASE_MODEL,
         trust_remote_code=True,
        padding_side="left",
-        local_files_only=False  # allow remote fetch (cached)
     )
     if tokenizer.pad_token_id is None:
-        tokenizer.add_special_tokens({"pad_token": "[PAD]"})
+        tokenizer.add_special_tokens({"pad_token":"[PAD]"})
 
-    # 2) Load base model (quantized on GPU if available)
+    # 2) Base model in 4-bit or fp16/32
     if torch.cuda.is_available():
-        bnb_config = BitsAndBytesConfig(
+        bnb = BitsAndBytesConfig(
             load_in_4bit=True,
             bnb_4bit_quant_type="nf4",
             bnb_4bit_compute_dtype="float16",
             low_cpu_mem_usage=True,
         )
         base = AutoModelForCausalLM.from_pretrained(
-            "microsoft/phi-2",
+            BASE_MODEL,
             trust_remote_code=True,
-            quantization_config=bnb_config,
-            device_map="auto",
-            offload_folder=OFFLOAD_DIR,
-            offload_state_dict=True,
+            quantization_config=bnb,
+            device_map=DEVICE_MAP,
         )
     else:
         dtype = torch.float16 if torch.cuda.is_available() else torch.float32
         base = AutoModelForCausalLM.from_pretrained(
-            "microsoft/phi-2",
+            BASE_MODEL,
             trust_remote_code=True,
             torch_dtype=dtype,
-            device_map="auto",
-            offload_folder=OFFLOAD_DIR,
-            offload_state_dict=True,
+            device_map=DEVICE_MAP,
         )
 
-    # 3) Resize embeddings to match tokenizer
+    # 3) Resize & wrap LoRA
     base.resize_token_embeddings(len(tokenizer))
+    peft_config = LoraConfig.from_pretrained(MODEL_REPO, local_files_only=True)
+    model = get_peft_model(base, peft_config)
 
-    # 4) Load LoRA config & wrap base
-    peft_config = LoraConfig.from_pretrained(
-        LOCAL_REPO,
-        local_files_only=True
-    )
-    model = get_peft_model(base, peft_config)
-
-    # 5) Manually load adapter weights
-    adapter_path = os.path.join(LOCAL_REPO, "adapter_model.safetensors")
-    state_dict = safe_load(adapter_path)
+    # 4) Load adapter weights (.safetensors)
+    adapter_file = os.path.join(MODEL_REPO, "adapter_model.safetensors")
+    state_dict = safe_load(adapter_file)
     model.load_state_dict(state_dict, strict=False)
     model.eval()
 
-    # 6) Build generation pipeline
+    # 5) Build pipeline (greedy for speed)
     gen = pipeline(
         "text-generation",
         model=model,
         tokenizer=tokenizer,
-        device_map="auto",
-        max_new_tokens=32,
-        do_sample=True,
-        temperature=0.2,
-        top_p=0.8,
+        device_map=DEVICE_MAP,
+        max_new_tokens=MAX_NEW_TOKENS,
+        do_sample=False,
+        temperature=TEMPERATURE,
+        top_p=TOP_P,
         use_cache=True,
         return_full_text=False,
     )
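
Aside: steps 3-4 above wrap the base model with get_peft_model() and then pull adapter_model.safetensors in by hand with strict=False. A more common route is PeftModel.from_pretrained, which reads the adapter config and weights in one call. The sketch below is an illustration, not what this commit ships; it assumes the same local models/phi2-deeptalk-lora folder and that an adapter_config.json sits next to the safetensors file.

# Sketch only: alternative adapter loading via PeftModel.from_pretrained.
from peft import PeftModel
from transformers import AutoModelForCausalLM, AutoTokenizer

BASE_MODEL = "microsoft/phi-2"
MODEL_REPO = "models/phi2-deeptalk-lora"   # same local path used in the commit

tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL, trust_remote_code=True)
base = AutoModelForCausalLM.from_pretrained(BASE_MODEL, trust_remote_code=True)
base.resize_token_embeddings(len(tokenizer))   # keep embeddings in sync, as above

# Reads adapter_config.json and adapter_model.safetensors in one call,
# replacing get_peft_model() plus the manual safe_load/load_state_dict step.
model = PeftModel.from_pretrained(base, MODEL_REPO)
model.eval()

This spares the manual os.path.join / safe_load bookkeeping while producing the same wrapped model.
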
@@ -117,49 +89,34 @@ def load_generator():
 
 tokenizer, generator = load_generator()
 
-# ── System prompt ───────────────────────────────────────────────────────────
-SYSTEM = (
-    "You are a helpful assistant for DeepTalks with a base model as Phi-2 "
-    "and fine tuned by Sourish for my domain specific role.\n"
-    "My domain is assisting you within my expertise by listening to you, "
-    "understanding you & supporting you.\n"
-    "Answer **only** using the information in the memory below.\n"
-    "If the answer is not in memory, reply: \"I don't know.\"\n"
-    "Do NOT repeat any lines beginning with 'User:'.\n"
-)
-
 # ── Streamlit UI ──────────────────────────────────────────────────────────
 st.set_page_config(layout="centered")
 st.title("🧠 Memory-Aware Phi-2 Chat")
 
+# initialize history
 if "history" not in st.session_state:
-    st.session_state.history = []  # list of (role, message)
+    st.session_state.history = []  # list of (role, text)
 
-# Render existing history
-for role, msg in st.session_state.history:
-    st.chat_message("user" if role == "You" else "assistant").write(msg)
+# render existing
+for role, text in st.session_state.history:
+    st.chat_message("user" if role=="You" else "assistant").write(text)
 
-# Input at bottom
+# user input
 user_input = st.chat_input("Type your message...")
 
 if user_input:
-    # Show & store user
+    # show user
     st.chat_message("user").write(user_input)
     st.session_state.history.append(("You", user_input))
-    add_mem("user", user_input)
-
-    # Fetch memories & build prompt
-    mems = fetch_mems(user_input, k=3)
-    mem_block = "\n".join(m["text"] for m in mems)
-    prompt = f"""{SYSTEM}
 
-Memory:
-{mem_block}
+    # build context from last turns
+    recent = st.session_state.history[-CONTEXT_TURNS*2:]  # each turn = 2 entries
+    ctx = "\n".join(f"{'User' if r=='You' else 'Assistant'}: {t}"
+                    for r,t in recent)
 
-User: {user_input}
-Assistant:"""
+    prompt = f"{SYSTEM}\nContext:\n{ctx}\nUser: {user_input}\nAssistant:"
 
-    # Generate with spinner
+    # generate
     with st.spinner("Thinking..."):
         try:
             out = generator(prompt)[0]["generated_text"].strip()
@@ -167,7 +124,6 @@ Assistant:"""
             out = "Sorry, I encountered an error."
             st.error(f"Generation error: {e}")
 
-    # Show & store assistant
+    # show bot
     st.chat_message("assistant").write(out)
     st.session_state.history.append(("Bot", out))
-    add_mem("assistant", out)
 