# DeepTalks / app.py
import os
import streamlit as st
import torch
from transformers import (
    pipeline,
    AutoTokenizer,
    AutoModelForCausalLM,
    BitsAndBytesConfig,
)
from peft import LoraConfig, get_peft_model, set_peft_model_state_dict
from safetensors.torch import load_file as safe_load

# ── Configuration ──────────────────────────────────────────────────────────
MODEL_REPO = "models/phi2-deeptalk-lora"  # local directory with the LoRA adapter
BASE_MODEL = "microsoft/phi-2"
CONTEXT_TURNS = 7       # how many past exchanges (user + assistant pairs) to include
MAX_NEW_TOKENS = 32     # shorter = faster
TEMPERATURE = 0.0       # kept for reference; greedy decoding ignores sampling params
TOP_P = 1.0             # likewise unused while do_sample=False
DEVICE_MAP = "auto"

SYSTEM = (
    "You are a helpful assistant for DeepTalks, built on a Phi-2 base model "
    "fine-tuned by Sourish for domain-specific support.\n"
    "Base your replies **only** on the context below. "
    "If you don't know, say “I don't know.”\n"
)
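
# The assembled prompt (built in the chat handler below) has this shape:
#   <SYSTEM>
#   Context:
#   User: ...
#   Assistant: ...
#   User: <latest message>
#   Assistant: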

# ── Model Loader ───────────────────────────────────────────────────────────
@st.cache_resource(show_spinner=False)
def load_generator():
    # 1) Tokenizer (always resolved from the HuggingFace cache)
    tokenizer = AutoTokenizer.from_pretrained(
        BASE_MODEL,
        trust_remote_code=True,
        padding_side="left",
    )
    # Phi-2 ships without a pad token; add one so padding works
    if tokenizer.pad_token_id is None:
        tokenizer.add_special_tokens({"pad_token": "[PAD]"})
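    # Alternative (an assumption about the training setup, not from this app):
    # reuse EOS as the pad token so the vocab keeps Phi-2's original size and
    # the resize_token_embeddings call below becomes a no-op:
    #   tokenizer.pad_token = tokenizer.eos_token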

    # 2) Base model: 4-bit NF4 quantization on GPU, full fp32 on CPU
    if torch.cuda.is_available():
        bnb = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_compute_dtype=torch.float16,
        )
        base = AutoModelForCausalLM.from_pretrained(
            BASE_MODEL,
            trust_remote_code=True,
            quantization_config=bnb,
            device_map=DEVICE_MAP,
            low_cpu_mem_usage=True,  # a from_pretrained kwarg, not a BitsAndBytesConfig field
        )
    else:
        # CPU-only path: fp16 is poorly supported on most CPUs, so use fp32
        base = AutoModelForCausalLM.from_pretrained(
            BASE_MODEL,
            trust_remote_code=True,
            torch_dtype=torch.float32,
            device_map=DEVICE_MAP,
        )

    # 3) Resize embeddings for the added [PAD] token, then wrap with LoRA
    base.resize_token_embeddings(len(tokenizer))
    peft_config = LoraConfig.from_pretrained(MODEL_REPO, local_files_only=True)
    model = get_peft_model(base, peft_config)

    # 4) Load adapter weights (.safetensors). set_peft_model_state_dict remaps
    #    the checkpoint keys onto the PEFT-wrapped module names; a plain
    #    load_state_dict(..., strict=False) would silently load nothing, since
    #    the saved keys lack the adapter name ("default") the wrapper inserts.
    adapter_file = os.path.join(MODEL_REPO, "adapter_model.safetensors")
    state_dict = safe_load(adapter_file)
    set_peft_model_state_dict(model, state_dict)
    model.eval()
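
    # Optional sanity check (a sketch, not in the original app): a silent key
    # mismatch is the classic failure mode when loading adapters by hand, so
    # verify the checkpoint really contained LoRA tensors.
    if not any("lora_" in key for key in state_dict):
        raise ValueError(f"{adapter_file} does not look like a LoRA checkpoint")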

    # 5) Build the generation pipeline (greedy decoding for speed). The model
    #    is already dispatched by accelerate, so no device_map is passed here,
    #    and TEMPERATURE/TOP_P are omitted because do_sample=False ignores them.
    gen = pipeline(
        "text-generation",
        model=model,
        tokenizer=tokenizer,
        max_new_tokens=MAX_NEW_TOKENS,
        do_sample=False,
        use_cache=True,
        return_full_text=False,  # return only the completion, not the prompt
    )
    return tokenizer, gen

# ── Streamlit UI ──────────────────────────────────────────────────────────
# st.set_page_config must be the first Streamlit call on the page, so it runs
# before the (cached) model load.
st.set_page_config(layout="centered")
st.title("🧠 Memory-Aware Phi-2 Chat")

tokenizer, generator = load_generator()

# initialize chat history
if "history" not in st.session_state:
    st.session_state.history = []  # list of (role, text) tuples

# re-render prior messages on each rerun
for role, text in st.session_state.history:
    st.chat_message("user" if role == "You" else "assistant").write(text)

# user input
user_input = st.chat_input("Type your message...")
if user_input:
    # show the user's message and record it
    st.chat_message("user").write(user_input)
    st.session_state.history.append(("You", user_input))

    # build context from the most recent turns (each turn = 2 history entries)
    recent = st.session_state.history[-CONTEXT_TURNS * 2:]
    ctx = "\n".join(
        f"{'User' if r == 'You' else 'Assistant'}: {t}" for r, t in recent
    )
    prompt = f"{SYSTEM}\nContext:\n{ctx}\nUser: {user_input}\nAssistant:"
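
    # Phi-2's context window is 2048 tokens, so long messages can overflow it
    # even with only CONTEXT_TURNS exchanges. A minimal guard, assuming a
    # ~1500-token prompt budget (the budget is an assumption, not from the
    # original app): drop the oldest exchange until the prompt fits.
    while recent and len(tokenizer(prompt)["input_ids"]) > 1500:
        recent = recent[2:]
        ctx = "\n".join(
            f"{'User' if r == 'You' else 'Assistant'}: {t}" for r, t in recent
        )
        prompt = f"{SYSTEM}\nContext:\n{ctx}\nUser: {user_input}\nAssistant:"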

    # generate
    with st.spinner("Thinking..."):
        try:
            out = generator(prompt)[0]["generated_text"].strip()
        except Exception as e:
            out = "Sorry, I encountered an error."
            st.error(f"Generation error: {e}")
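
    # Phi-2 is a base model, so greedy decoding often runs past the answer and
    # starts inventing the next turn. Trimming at the first speaker marker is a
    # common guard; a minimal sketch (the marker strings simply mirror the
    # prompt format built above):
    for marker in ("\nUser:", "\nAssistant:"):
        if marker in out:
            out = out.split(marker, 1)[0].rstrip()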

    # show the assistant's reply and record it
    st.chat_message("assistant").write(out)
    st.session_state.history.append(("Bot", out))