sourize committed · verified
Commit df4e3a8 · 1 Parent(s): 17d9700

Update app.py

Files changed (1)
  1. app.py +46 -101
app.py CHANGED
@@ -1,131 +1,76 @@
  import os
  import streamlit as st
- import torch
- import logging
- from transformers import (
-     pipeline,
-     AutoTokenizer,
-     AutoModelForCausalLM,
-     BitsAndBytesConfig,
- )
- from peft import PeftModel

  # ── Configuration ──────────────────────────────────────────────────────────
- BASE_MODEL = "microsoft/phi-2"
- ADAPTER_REPO = "sourize/phi2-memory-lora"
- CONTEXT_TURNS = 6
- MAX_NEW_TOKENS = 128
- OFFLOAD_DIR = "offload"
-
- SYSTEM = (
-     "You are a helpful assistant for DeepTalks with base Phi-2\n"
-     "fine-tuned by Sourish for domain support.\n"
      "Answer **only** using the conversation context below.\n"
      "Do NOT output any lines beginning with 'User:' or 'Assistant:'.\n"
      "If you don't know, say \"I don't know.\"\n"
  )

- @st.cache_resource(show_spinner=False)
- def load_pipeline():
-     # 1) Tokenizer
-     tokenizer = AutoTokenizer.from_pretrained(
-         BASE_MODEL, trust_remote_code=True, padding_side="left"
-     )
-     if tokenizer.pad_token_id is None:
-         tokenizer.add_special_tokens({"pad_token": "[PAD]"})
-
-     # 2) Base model: 4-bit on CUDA, plain FP16/FP32 on CPU
-     if torch.cuda.is_available():
-         quant_config = BitsAndBytesConfig(
-             load_in_4bit=True,
-             bnb_4bit_quant_type="nf4",
-             bnb_4bit_compute_dtype="float16",
-             low_cpu_mem_usage=True,
-         )
-         base = AutoModelForCausalLM.from_pretrained(
-             BASE_MODEL,
-             trust_remote_code=True,
-             quantization_config=quant_config,
-             device_map="auto",
-             offload_folder=OFFLOAD_DIR,
-             offload_state_dict=True,
-         )
-     else:
-         dtype = torch.float16 if torch.cuda.is_available() else torch.float32
-         base = AutoModelForCausalLM.from_pretrained(
-             BASE_MODEL,
-             trust_remote_code=True,
-             torch_dtype=dtype,
-             device_map="cpu",  # force CPU
-         )
-
-     # 3) Resize + LoRA overlay
-     base.resize_token_embeddings(len(tokenizer))
-     model = PeftModel.from_pretrained(
-         base,
-         ADAPTER_REPO,
-         trust_remote_code=True,
-         device_map="auto" if torch.cuda.is_available() else None,
-         torch_dtype=None,
-     )
-     model.eval()
-
-     # 4) Build generation pipeline
-     gen = pipeline(
-         "text-generation",
-         model=model,
-         tokenizer=tokenizer,
-         device_map="auto" if torch.cuda.is_available() else None,
-         max_new_tokens=MAX_NEW_TOKENS,
-         do_sample=True,
-         temperature=0.7,
-         top_p=0.9,
-         use_cache=True,
-         return_full_text=False,
      )
-
-     logging.info("Pipeline loaded.")
-     return gen
-
- generator = load_pipeline()

  # ── Streamlit UI ──────────────────────────────────────────────────────────
  st.set_page_config(layout="centered")
- st.title("🧠 DeepTalks")
- st.subheader("Your personal AI Companion", divider='grey')

  if "history" not in st.session_state:
-     st.session_state.history = []

  for role, text in st.session_state.history:
-     st.chat_message("user" if role == "You" else "assistant").write(text)

- user_input = st.chat_input("Your message…")
  if user_input:
      st.chat_message("user").write(user_input)
      st.session_state.history.append(("You", user_input))

      recent = st.session_state.history[-CONTEXT_TURNS*2:]
-     context = "\n".join(t for _, t in recent)
-     prompt = f"""{SYSTEM}
-
- Context:
- {context}
-
- User: {user_input}
- Assistant:"""

      with st.spinner("Thinking…"):
          try:
-             reply = generator(prompt)[0]["generated_text"].strip()
-             for marker in ["User:", "Assistant:"]:
-                 if marker in reply:
-                     reply = reply.split(marker)[0].strip()
-             if not reply:
-                 reply = "I’m sorry, I didn’t catch that. Could you rephrase?"
          except Exception as e:
-             reply = "I’m sorry, something went wrong."
-             st.error(f"Error: {e}")

      st.chat_message("assistant").write(reply)
      st.session_state.history.append(("Bot", reply))
 
  import os
  import streamlit as st
+ from huggingface_hub import InferenceClient

  # ── Configuration ──────────────────────────────────────────────────────────
+ HF_TOKEN = os.getenv("HF_TOKEN")  # store your token in Space Secrets
+ MODEL_ID = "sourize/phi2-memory-lora"
+ CONTEXT_TURNS = 7
+ MAX_NEW_TOKENS = 128
+ SYSTEM_PROMPT = (
+     "You are a helpful assistant for DeepTalks with base Phi-2 "
+     "fine-tuned by Sourish.\n"
      "Answer **only** using the conversation context below.\n"
      "Do NOT output any lines beginning with 'User:' or 'Assistant:'.\n"
      "If you don't know, say \"I don't know.\"\n"
  )

+ # ── HF Inference client ─────────────────────────────────────────────────────
+ client = InferenceClient(token=HF_TOKEN)
+
+ def query_hf(prompt: str) -> str:
+     # InferenceClient.text_generation takes generation kwargs directly;
+     # with the default details=False it returns the generated text as a str.
+     out = client.text_generation(
+         prompt,
+         model=MODEL_ID,
+         max_new_tokens=MAX_NEW_TOKENS,
+         do_sample=True,
+         temperature=0.7,
+         top_p=0.9,
+         return_full_text=False,
      )
+     text = out.strip()
+     # strip any stray markers
+     for marker in ["User:", "Assistant:"]:
+         if marker in text:
+             text = text.split(marker)[0].strip()
+     return text or "I don't know."

  # ── Streamlit UI ──────────────────────────────────────────────────────────
  st.set_page_config(layout="centered")
+ st.title("🧠 DeepTalks (Inference API)")
+ st.subheader("Your personal AI Companion")

  if "history" not in st.session_state:
+     st.session_state.history = []  # tuples of (role, text)

+ # render history
  for role, text in st.session_state.history:
+     st.chat_message("user" if role == "You" else "assistant").write(text)

+ # new input
+ user_input = st.chat_input("Type your message…")
  if user_input:
      st.chat_message("user").write(user_input)
      st.session_state.history.append(("You", user_input))

+     # build context
      recent = st.session_state.history[-CONTEXT_TURNS*2:]
+     ctx = "\n".join(text for _, text in recent)
+     prompt = (
+         f"{SYSTEM_PROMPT}\n\n"
+         f"Context:\n{ctx}\n\n"
+         f"User: {user_input}\nAssistant:"
+     )

+     # call HF Inference API
      with st.spinner("Thinking…"):
          try:
+             reply = query_hf(prompt)
          except Exception as e:
+             reply = "Error generating response."
+             st.error(e)

      st.chat_message("assistant").write(reply)
      st.session_state.history.append(("Bot", reply))
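
For a quick check outside the Space, here is a minimal smoke test of the new inference path. This is a sketch under assumptions: huggingface_hub is installed, HF_TOKEN is set in the environment, and the sourize/phi2-memory-lora repo is actually servable through the hosted Inference API (adapter-only repos are not always deployed there; if not, the call fails with an HTTP error). The two-turn history and the shortened system text are made up for illustration.

import os
from huggingface_hub import InferenceClient

MODEL_ID = "sourize/phi2-memory-lora"
client = InferenceClient(token=os.getenv("HF_TOKEN"))

# Assemble a prompt the same way app.py does: system text, the flattened
# recent history, then the pending user turn.
history = [("You", "hi"), ("Bot", "Hello! How can I help?")]
ctx = "\n".join(text for _, text in history)
prompt = (
    "You are a helpful assistant for DeepTalks.\n\n"
    f"Context:\n{ctx}\n\n"
    "User: What can you help me with?\nAssistant:"
)

# With the default details=False, text_generation returns a plain string,
# unlike the old local pipeline, which returned a list of dicts.
reply = client.text_generation(
    prompt,
    model=MODEL_ID,
    max_new_tokens=64,
    do_sample=True,
    temperature=0.7,
    top_p=0.9,
    return_full_text=False,
)
print(reply.strip())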