# DeepTalks / app.py
import os
import streamlit as st
import torch
import logging
from transformers import (
pipeline,
AutoTokenizer,
AutoModelForCausalLM,
BitsAndBytesConfig,
)
from peft import PeftModel
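# Assumed runtime dependencies (the original file doesn't pin them):
# streamlit, torch, transformers, and peft, plus bitsandbytes and
# accelerate for the 4-bit CUDA path below.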
# ── Configuration ──────────────────────────────────────────────────────────
BASE_MODEL = "microsoft/phi-2"
ADAPTER_REPO = "sourize/phi2-memory-deeptalks"
CONTEXT_TURNS = 7
MAX_NEW_TOKENS = 128
OFFLOAD_DIR = "offload"
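# History is stored one line per message, so CONTEXT_TURNS exchanges
# correspond to the last CONTEXT_TURNS * 2 history entries (see the
# slice in the chat handler below).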
SYSTEM = (
    "You are a helpful assistant for DeepTalks, built on the Phi-2 base model\n"
    "and fine-tuned by Sourish for personal conversations.\n"
    "Answer **only** using the conversation context below.\n"
    "Do NOT output any lines beginning with 'User:' or 'Assistant:'.\n"
    "If you don't know, say \"I don't know.\"\n"
)
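# For reference, the prompt assembled in the chat handler below comes out
# shaped roughly like this (illustrative transcript, not real data):
#
#   <SYSTEM text>
#   Context:
#   hi there
#   Hello! How can I help you today?
#
#   User: what did I just say?
#   Assistant: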
@st.cache_resource(show_spinner=False)
def load_pipeline():
    # 1) Tokenizer
    tokenizer = AutoTokenizer.from_pretrained(
        BASE_MODEL, trust_remote_code=True, padding_side="left"
    )
    if tokenizer.pad_token_id is None:
        tokenizer.add_special_tokens({"pad_token": "[PAD]"})
    # 2) Base model: 4-bit NF4 quantization on CUDA, plain FP32 on CPU
    if torch.cuda.is_available():
        quant_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_compute_dtype=torch.float16,
        )
        os.makedirs(OFFLOAD_DIR, exist_ok=True)  # offload target must exist
        base = AutoModelForCausalLM.from_pretrained(
            BASE_MODEL,
            trust_remote_code=True,
            quantization_config=quant_config,
            device_map="auto",
            low_cpu_mem_usage=True,  # from_pretrained kwarg, not a BitsAndBytesConfig field
            offload_folder=OFFLOAD_DIR,
            offload_state_dict=True,
        )
    else:
        # No CUDA in this branch, so FP16 buys nothing; load plain FP32 on the CPU.
        base = AutoModelForCausalLM.from_pretrained(
            BASE_MODEL,
            trust_remote_code=True,
            torch_dtype=torch.float32,
            device_map="cpu",  # force CPU
        )
    # 3) Resize embeddings for the new pad token, then overlay the LoRA adapter.
    # Resizing must happen before the adapter is attached so the embedding
    # shapes match what the fine-tuned checkpoint expects.
    base.resize_token_embeddings(len(tokenizer))
    model = PeftModel.from_pretrained(
        base,
        ADAPTER_REPO,
        device_map="auto" if torch.cuda.is_available() else None,
    )
    model.eval()
    # 4) Build the generation pipeline around the already-placed model
    # (no device_map here: the model was dispatched during loading above)
    gen = pipeline(
        "text-generation",
        model=model,
        tokenizer=tokenizer,
        max_new_tokens=MAX_NEW_TOKENS,
        do_sample=True,
        temperature=0.7,
        top_p=0.9,
        use_cache=True,
        return_full_text=False,
    )
    logging.info("Pipeline loaded.")
    return gen
generator = load_pipeline()
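# Hypothetical smoke test, left commented out so it doesn't run on every
# Space boot (the prompt string here is illustrative only):
#
#   reply = generator("User: hi\nAssistant:")[0]["generated_text"]
#   logging.info("Smoke test reply: %s", reply)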
# ── Streamlit UI ──────────────────────────────────────────────────────────
st.set_page_config(layout="centered")
st.title("🧠 DeepTalks")
st.markdown("⏳ Responses can take a while to generate since this Space runs on the free CPU tier.")
st.subheader("Your personal AI Companion", divider='grey')
if "history" not in st.session_state:
    st.session_state.history = []

# Replay the conversation so far
for role, text in st.session_state.history:
    st.chat_message("user" if role == "You" else "assistant").write(text)
user_input = st.chat_input("Your message…")
if user_input:
    st.chat_message("user").write(user_input)
    st.session_state.history.append(("You", user_input))

    # Keep the last CONTEXT_TURNS exchanges (two history lines per exchange)
    recent = st.session_state.history[-CONTEXT_TURNS * 2:]
    context = "\n".join(t for _, t in recent)

    prompt = f"""{SYSTEM}
Context:
{context}

User: {user_input}
Assistant:"""
    with st.spinner("Thinking…"):
        try:
            reply = generator(prompt)[0]["generated_text"].strip()
            # The model sometimes echoes dialogue markers; keep only the
            # text before the first one.
            for marker in ("User:", "Assistant:"):
                if marker in reply:
                    reply = reply.split(marker)[0].strip()
            if not reply:
                reply = "I’m sorry, I didn’t catch that. Could you rephrase?"
        except Exception as e:
            reply = "I’m sorry, something went wrong."
            st.error(f"Error: {e}")

    st.chat_message("assistant").write(reply)
    st.session_state.history.append(("Bot", reply))