sourize committed
Commit 857744a · 1 Parent(s): d934644
Files changed (2):
  1. app.py +39 -39
  2. requirements.txt +5 -6
app.py CHANGED
@@ -11,34 +11,30 @@ from peft import LoraConfig, get_peft_model
 from safetensors.torch import load_file as safe_load
 
 # ── Configuration ──────────────────────────────────────────────────────────
-MODEL_REPO = "models/phi2-deeptalk-lora"
 BASE_MODEL = "microsoft/phi-2"
-CONTEXT_TURNS = 7    # how many past messages to include
-MAX_NEW_TOKENS = 32  # shorter = faster
-TEMPERATURE = 0.0    # 0.0 = greedy
-TOP_P = 1.0          # disable nucleus sampling
+ADAPTER_DIR = os.path.join(os.path.dirname(__file__), "models", "phi2-deeptalk-lora")
+CONTEXT_TURNS = 6
+MAX_NEW_TOKENS = 32
 DEVICE_MAP = "auto"
 
 SYSTEM = (
-    "You are a helpful assistant for DeepTalks with a base model Phi-2 "
-    "fine-tuned by Sourish for domain-specific support.\n"
-    "Base replies **only** on the context below. "
-    "If you don't know, say “I don't know.”\n"
+    "You are a helpful assistant for DeepTalks with base Phi-2 fine-tuned "
+    "by Sourish for domain support.\n"
+    "Answer only using the conversation context below.\n"
+    "If you don’t know, say “I don't know.”\n"
 )
 
-# ── Model Loader ───────────────────────────────────────────────────────────
+# ── Model loader (cached) ──────────────────────────────────────────────────
 @st.cache_resource(show_spinner=False)
 def load_generator():
-    # 1) Tokenizer (always from HuggingFace cache)
+    # 1) Tokenizer (from official HF cache)
     tokenizer = AutoTokenizer.from_pretrained(
-        BASE_MODEL,
-        trust_remote_code=True,
-        padding_side="left",
+        BASE_MODEL, trust_remote_code=True, padding_side="left"
     )
     if tokenizer.pad_token_id is None:
-        tokenizer.add_special_tokens({"pad_token":"[PAD]"})
+        tokenizer.add_special_tokens({"pad_token": "[PAD]"})
 
-    # 2) Base model in 4-bit or fp16/32
+    # 2) Base model (4-bit on GPU, else FP16/FP32)
     if torch.cuda.is_available():
         bnb = BitsAndBytesConfig(
             load_in_4bit=True,
@@ -61,27 +57,25 @@ def load_generator():
         device_map=DEVICE_MAP,
     )
 
-    # 3) Resize & wrap LoRA
+    # 3) Resize embeddings & wrap LoRA
     base.resize_token_embeddings(len(tokenizer))
-    peft_config = LoraConfig.from_pretrained(MODEL_REPO, local_files_only=True)
+    peft_config = LoraConfig.from_pretrained(ADAPTER_DIR, local_files_only=True)
    model = get_peft_model(base, peft_config)
 
     # 4) Load adapter weights (.safetensors)
-    adapter_file = os.path.join(MODEL_REPO, "adapter_model.safetensors")
+    adapter_file = os.path.join(ADAPTER_DIR, "adapter_model.safetensors")
     state_dict = safe_load(adapter_file)
     model.load_state_dict(state_dict, strict=False)
     model.eval()
 
-    # 5) Build pipeline (greedy for speed)
+    # 5) Build a **greedy** pipeline for max speed
     gen = pipeline(
         "text-generation",
         model=model,
         tokenizer=tokenizer,
         device_map=DEVICE_MAP,
         max_new_tokens=MAX_NEW_TOKENS,
-        do_sample=False,
-        temperature=TEMPERATURE,
-        top_p=TOP_P,
+        do_sample=False,  # greedy
         use_cache=True,
         return_full_text=False,
     )
@@ -93,37 +87,43 @@ tokenizer, generator = load_generator()
 st.set_page_config(layout="centered")
 st.title("🧠 Memory-Aware Phi-2 Chat")
 
-# initialize history
+# Initialize chat history
 if "history" not in st.session_state:
     st.session_state.history = []  # list of (role, text)
 
-# render existing
+# Render past messages
 for role, text in st.session_state.history:
     st.chat_message("user" if role=="You" else "assistant").write(text)
 
-# user input
-user_input = st.chat_input("Type your message...")
+# User input at bottom
+user_input = st.chat_input("Your message…")
 
 if user_input:
-    # show user
+    # Show/store user turn
     st.chat_message("user").write(user_input)
     st.session_state.history.append(("You", user_input))
 
-    # build context from last turns
-    recent = st.session_state.history[-CONTEXT_TURNS*2:]  # each turn = 2 entries
-    ctx = "\n".join(f"{'User' if r=='You' else 'Assistant'}: {t}"
-                    for r,t in recent)
-
-    prompt = f"{SYSTEM}\nContext:\n{ctx}\nUser: {user_input}\nAssistant:"
+    # Build context from last N turns
+    recent = st.session_state.history[-CONTEXT_TURNS*2:]
+    ctx = "\n".join(
+        f"{'User' if r=='You' else 'Assistant'}: {t}"
+        for r,t in recent
+    )
+    prompt = f"""{SYSTEM}
+
+Context:
+{ctx}
+
+User: {user_input}
+Assistant:"""
 
-    # generate
-    with st.spinner("Thinking..."):
+    # Generate reply (spinner)
+    with st.spinner("Thinking…"):
         try:
             out = generator(prompt)[0]["generated_text"].strip()
         except Exception as e:
-            out = "Sorry, I encountered an error."
-            st.error(f"Generation error: {e}")
+            out = "I’m sorry, something went wrong."
+            st.error(f"Error: {e}")
 
-    # show bot
+    # Show/store assistant
     st.chat_message("assistant").write(out)
     st.session_state.history.append(("Bot", out))
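
A side note on steps 3 and 4 of the loader: when an adapter directory was written by PEFT's save_pretrained() (i.e. it holds adapter_config.json next to adapter_model.safetensors), PeftModel.from_pretrained can attach the LoRA modules and load the weights in one call, replacing the manual get_peft_model + safe_load + load_state_dict(strict=False) sequence. This is not what the commit does; it is a minimal sketch, assuming the same base, tokenizer, and ADAPTER_DIR as above:

    from peft import PeftModel

    # Keep vocab sizes in sync, as the commit already does.
    base.resize_token_embeddings(len(tokenizer))
    # Attaches the LoRA layers and loads adapter_model.safetensors in one step.
    model = PeftModel.from_pretrained(base, ADAPTER_DIR)
    model.eval()

One advantage of this route is that a wrong or incomplete adapter file tends to fail loudly, whereas load_state_dict(strict=False) silently skips mismatched keys.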
requirements.txt CHANGED
@@ -1,7 +1,6 @@
 streamlit
-transformers>=4.30
-peft
-supabase
-sentence-transformers
-faiss-cpu
-bitsandbytes
+transformers>=4.51
+peft>=0.15.2
+bitsandbytes   # for 4-bit GPU speed
+safetensors    # for loading your adapter file
+torch          # match your target CUDA/CPU build
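
Note on the unpinned torch line: pip will install whatever wheel PyPI serves for the platform. If the host needs a specific CUDA (or CPU-only) build, PyTorch's extra index is the usual route; the cu121 tag here is only illustrative:

    pip install torch --index-url https://download.pytorch.org/whl/cu121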