sourize committed · acc9519
1 Parent(s): b3d517e
app.py CHANGED
@@ -20,7 +20,6 @@ supabase = create_client(SUPA_URL, SUPA_KEY)
 @st.cache_resource(show_spinner=False)
 def get_embedder():
     return SentenceTransformer("paraphrase-MiniLM-L3-v2")
-
 embedder = get_embedder()
 
 @st.cache_data(show_spinner=False)
@@ -42,16 +41,17 @@ def add_mem(speaker, text):
 # ── Model + tokenizer (adapter locally, tokenizer remote) ─────────────────
 @st.cache_resource(show_spinner=False)
 def load_generator():
-    base_dir
-    LOCAL_REPO
-    OFFLOAD_DIR= os.path.join(base_dir, "offload")
+    base_dir = os.path.dirname(__file__)
+    LOCAL_REPO = os.path.join(base_dir, "models", "phi2-memory-lora")
+    OFFLOAD_DIR = os.path.join(base_dir, "offload")
     os.makedirs(OFFLOAD_DIR, exist_ok=True)
 
-    # 1) Tokenizer from official Phi-2
+    # 1) Tokenizer from official Phi-2
     tokenizer = AutoTokenizer.from_pretrained(
         "microsoft/phi-2",
         trust_remote_code=True,
-        padding_side="left"
+        padding_side="left",
+        local_files_only=False  # allow remote fetch (cached)
     )
     if tokenizer.pad_token_id is None:
         tokenizer.add_special_tokens({"pad_token": "[PAD]"})
@@ -87,16 +87,21 @@ def load_generator():
     base.resize_token_embeddings(len(tokenizer))
 
     # 4) Load LoRA config & wrap base
-    peft_config = LoraConfig.from_pretrained(
+    peft_config = LoraConfig.from_pretrained(
+        LOCAL_REPO,
+        local_files_only=True
+    )
     model = get_peft_model(base, peft_config)
 
     # 5) Manually load adapter weights
-    adapter_path = os.path.join(
-
+    adapter_path = os.path.join(
+        LOCAL_REPO, "adapter_model", "pytorch_model.bin"
+    )
+    state_dict = torch.load(adapter_path, map_location="cpu")
     model.load_state_dict(state_dict, strict=False)
     model.eval()
 
-    # 6)
+    # 6) Build generation pipeline
     gen = pipeline(
         "text-generation",
         model=model,
@@ -109,7 +114,7 @@ def load_generator():
         temperature=0.2,
         top_p=0.8,
         use_cache=True,
-        return_full_text=False
+        return_full_text=False,
     )
     return tokenizer, gen
 
@@ -135,19 +140,20 @@ if "history" not in st.session_state:
 
 # Render existing history
 for role, msg in st.session_state.history:
-    st.chat_message("user" if role=="You" else "assistant").write(msg)
+    st.chat_message("user" if role == "You" else "assistant").write(msg)
 
 # Input at bottom
 user_input = st.chat_input("Type your message...")
 
 if user_input:
+    # Show & store user
     st.chat_message("user").write(user_input)
     st.session_state.history.append(("You", user_input))
     add_mem("user", user_input)
 
+    # Fetch memories & build prompt
     mems = fetch_mems(user_input, k=3)
     mem_block = "\n".join(m["text"] for m in mems)
-
     prompt = f"""{SYSTEM}
 
 Memory:
@@ -156,6 +162,7 @@ Memory:
 User: {user_input}
 Assistant:"""
 
+    # Generate with spinner
     with st.spinner("Thinking..."):
         try:
             out = generator(prompt)[0]["generated_text"].strip()
@@ -163,6 +170,7 @@ Assistant:"""
             out = "Sorry, I encountered an error."
             st.error(f"Generation error: {e}")
 
+    # Show & store assistant
     st.chat_message("assistant").write(out)
    st.session_state.history.append(("Bot", out))
    add_mem("assistant", out)