import streamlit as st
import torch
import logging
from transformers import (
    pipeline,
    AutoTokenizer,
    AutoModelForCausalLM,
    BitsAndBytesConfig,
)
from peft import PeftModel
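
# App flow: load quantized Phi-2 plus the LoRA adapter once, then serve a
# Streamlit chat UI that prompts the model with recent conversation turns.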
# ── Configuration ──────────────────────────────────────────────────────────
BASE_MODEL = "microsoft/phi-2"
ADAPTER_REPO = "sourize/phi2-memory-deeptalks"
CONTEXT_TURNS = 7          # recent exchanges to include in the prompt
MAX_NEW_TOKENS = 128
OFFLOAD_DIR = "offload"    # accelerate spills out-of-memory layers here
SYSTEM = (
    "You are a helpful assistant for DeepTalks, built on the Phi-2 base model\n"
    "and fine-tuned by Sourish for personal conversations.\n"
    "Answer **only** using the conversation context below.\n"
    "Do NOT output any lines beginning with 'User:' or 'Assistant:'.\n"
    "If you don't know, say \"I don't know.\"\n"
)
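
# st.cache_resource keeps one loaded pipeline per server process, so the
# multi-GB model is not reloaded on every Streamlit rerun.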
@st.cache_resource(show_spinner=False)
def load_pipeline():
    # 1) Tokenizer (left padding so generation continues from the end of the prompt)
    tokenizer = AutoTokenizer.from_pretrained(
        BASE_MODEL, trust_remote_code=True, padding_side="left"
    )
    if tokenizer.pad_token_id is None:
        tokenizer.add_special_tokens({"pad_token": "[PAD]"})
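    # NOTE: adding [PAD] grows the vocabulary, which is why the embedding
    # matrix is resized below before the adapter is attached.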
    # 2) Base model: 4-bit NF4 quantization on CUDA, plain FP32 on CPU
    if torch.cuda.is_available():
        quant_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_compute_dtype="float16",
        )
        base = AutoModelForCausalLM.from_pretrained(
            BASE_MODEL,
            trust_remote_code=True,
            quantization_config=quant_config,
            device_map="auto",
            low_cpu_mem_usage=True,
            offload_folder=OFFLOAD_DIR,
            offload_state_dict=True,
        )
    else:
        # This branch only runs without CUDA, so FP32 is the right dtype here.
        base = AutoModelForCausalLM.from_pretrained(
            BASE_MODEL,
            trust_remote_code=True,
            torch_dtype=torch.float32,
            device_map="cpu",  # force CPU
        )
    # 3) Resize embeddings to match the tokenizer (including [PAD]), then
    #    overlay the LoRA adapter on the base weights.
    base.resize_token_embeddings(len(tokenizer))
    model = PeftModel.from_pretrained(
        base,
        ADAPTER_REPO,
        trust_remote_code=True,
        device_map="auto" if torch.cuda.is_available() else None,
        torch_dtype=None,
    )
    model.eval()  # inference mode: disables dropout
    # 4) Build the generation pipeline with sampling defaults
    gen = pipeline(
        "text-generation",
        model=model,
        tokenizer=tokenizer,
        device_map="auto" if torch.cuda.is_available() else None,
        max_new_tokens=MAX_NEW_TOKENS,
        do_sample=True,
        temperature=0.7,
        top_p=0.9,
        use_cache=True,
        return_full_text=False,  # return only the completion, not the echoed prompt
    )
    logging.info("Pipeline loaded.")
    return gen
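
# Load eagerly at import time: the first visitor pays the startup cost once,
# and later sessions reuse the cached pipeline.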
generator = load_pipeline()
# ── Streamlit UI ───────────────────────────────────────────────────────────
st.set_page_config(layout="centered")
st.title("🧠 DeepTalks")
st.markdown("⏳ Responses take a while to generate since this app runs on the free CPU tier.")
st.subheader("Your personal AI Companion", divider='grey')
if "history" not in st.session_state:
st.session_state.history = []
for role, text in st.session_state.history:
st.chat_message("user" if role == "You" else "assistant").write(text)
user_input = st.chat_input("Your message…")
if user_input:
    st.chat_message("user").write(user_input)
    st.session_state.history.append(("You", user_input))

    # Keep the last CONTEXT_TURNS exchanges (user + assistant = 2 entries each).
    recent = st.session_state.history[-CONTEXT_TURNS * 2:]
    context = "\n".join(t for _, t in recent)

    # The f-string body stays at column 0 so no stray indentation leaks into the prompt.
    prompt = f"""{SYSTEM}
Context:
{context}
User: {user_input}
Assistant:"""
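
    # generator() returns a list of dicts; with return_full_text=False the
    # "generated_text" field holds only the new completion.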
    with st.spinner("Thinking…"):
        try:
            reply = generator(prompt)[0]["generated_text"].strip()
            # Trim any extra dialogue turns the model tries to continue.
            for marker in ["User:", "Assistant:"]:
                if marker in reply:
                    reply = reply.split(marker)[0].strip()
            if not reply:
                reply = "I'm sorry, I didn't catch that. Could you rephrase?"
        except Exception as e:
            reply = "I'm sorry, something went wrong."
            st.error(f"Error: {e}")

    st.chat_message("assistant").write(reply)
    st.session_state.history.append(("Bot", reply))