import streamlit as st
import torch
import logging
from transformers import (
    pipeline,
    AutoTokenizer,
    AutoModelForCausalLM,
    BitsAndBytesConfig,
)
from peft import PeftModel
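
# DeepTalks: a Streamlit chat app serving microsoft/phi-2 with a LoRA adapter
# pulled from the Hub. GPU hosts load the base model in 4-bit NF4; CPU-only
# hosts fall back to full-precision fp32.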

# ── Configuration ──────────────────────────────────────────────────────────
BASE_MODEL     = "microsoft/phi-2"
ADAPTER_REPO   = "sourize/phi2-memory-deeptalks"
CONTEXT_TURNS  = 7
MAX_NEW_TOKENS = 128
OFFLOAD_DIR    = "offload"
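
# CONTEXT_TURNS bounds how many past exchanges are folded into each prompt;
# OFFLOAD_DIR is where accelerate spills weights that don't fit in GPU memory.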

SYSTEM = (
    "You are a helpful assistant for DeepTalks with base Phi-2\n"
    "Fine-tuned by Sourish for making personal conversations.\n"
    "Answer **only** using the conversation context below.\n"
    "Do NOT output any lines beginning with 'User:' or 'Assistant:'.\n"
    "If you don't know, say \"I don't know.\"\n"
)
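
# The last two SYSTEM rules are enforced again after generation: the reply is
# truncated at any echoed "User:"/"Assistant:" marker before being displayed.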

@st.cache_resource(show_spinner=False)
def load_pipeline():
    # 1) Tokenizer
    tokenizer = AutoTokenizer.from_pretrained(
        BASE_MODEL, trust_remote_code=True, padding_side="left"
    )
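    # Phi-2's tokenizer ships without a pad token; register one so padded
    # generation works (the embedding table is resized to match below).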
    if tokenizer.pad_token_id is None:
        tokenizer.add_special_tokens({"pad_token": "[PAD]"})

    # 2) Base model: 4-bit on CUDA, plain FP16/FP32 on CPU
    if torch.cuda.is_available():
        quant_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_compute_dtype=torch.float16,
        )
        base = AutoModelForCausalLM.from_pretrained(
            BASE_MODEL,
            trust_remote_code=True,
            quantization_config=quant_config,
            device_map="auto",
            low_cpu_mem_usage=True,     # from_pretrained arg, not a BitsAndBytesConfig field
            offload_folder=OFFLOAD_DIR,
            offload_state_dict=True,
        )
    else:
        # CPU-only host: fp16 matmuls are slow or unsupported on most CPUs,
        # so load in full fp32.
        base = AutoModelForCausalLM.from_pretrained(
            BASE_MODEL,
            trust_remote_code=True,
            torch_dtype=torch.float32,
            low_cpu_mem_usage=True,
            device_map="cpu",           # force CPU
        )

    # 3) Resize embeddings to cover the added [PAD] token, then overlay LoRA
    base.resize_token_embeddings(len(tokenizer))
    model = PeftModel.from_pretrained(
        base,
        ADAPTER_REPO,
        device_map="auto" if torch.cuda.is_available() else None,
    )
    model.eval()

    # 4) Build generation pipeline
    gen = pipeline(
        "text-generation",
        model=model,
        tokenizer=tokenizer,
        device_map="auto" if torch.cuda.is_available() else None,
        max_new_tokens=MAX_NEW_TOKENS,
        do_sample=True,
        temperature=0.7,
        top_p=0.9,
        use_cache=True,
        return_full_text=False,
    )
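    # return_full_text=False strips the prompt, so the pipeline hands back only
    # the newly generated completion.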

    logging.info("Pipeline loaded.")
    return gen

generator = load_pipeline()
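# st.cache_resource ensures the heavyweight load above happens once per
# process; subsequent script reruns reuse the same pipeline object.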

# ── Streamlit UI ──────────────────────────────────────────────────────────
st.set_page_config(layout="centered")
st.title("🧠 DeepTalks")
st.markdown("⏳ Responses take a while to generate since this app runs on the free CPU tier")
st.subheader("Your personal AI Companion", divider='grey')

if "history" not in st.session_state:
    st.session_state.history = []

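# Streamlit re-executes this script top-to-bottom on every interaction, so the
# transcript is redrawn from session state on each run.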
for role, text in st.session_state.history:
    st.chat_message("user" if role == "You" else "assistant").write(text)

user_input = st.chat_input("Your message…")
if user_input:
    st.chat_message("user").write(user_input)
    st.session_state.history.append(("You", user_input))

    # Build the context from the most recent turns, excluding the message just
    # appended (it is re-added explicitly as "User: ..." below).
    recent  = st.session_state.history[:-1][-CONTEXT_TURNS * 2:]
    context = "\n".join(t for _, t in recent)
    prompt  = f"""{SYSTEM}

Context:
{context}

User: {user_input}
Assistant:"""

    with st.spinner("Thinking…"):
        try:
            reply = generator(prompt)[0]["generated_text"].strip()
            for marker in ["User:", "Assistant:"]:
                if marker in reply:
                    reply = reply.split(marker)[0].strip()
            if not reply:
                reply = "I’m sorry, I didn’t catch that. Could you rephrase?"
        except Exception as e:
            reply = "I’m sorry, something went wrong."
            st.error(f"Error: {e}")

    st.chat_message("assistant").write(reply)
    st.session_state.history.append(("Bot", reply))