sourize committed on
Commit fef32cf · verified · 1 Parent(s): 11db01a

Update app.py

Files changed (1)
  1. app.py +97 -41
app.py CHANGED
@@ -1,13 +1,23 @@
  import os
  import streamlit as st
- import requests

  # ── Configuration ──────────────────────────────────────────────────────────
- HF_TOKEN = os.getenv("HF_TOKEN")  # read-only token in Space secrets
- MODEL_ID = "sourize/phi2-memory-lora"
- CONTEXT_TURNS = 6

- SYSTEM_PROMPT = (
      "You are a helpful assistant for DeepTalks with base Phi-2\n"
      "fine-tuned by Sourish for domain support.\n"
      "Answer **only** using the conversation context below.\n"
@@ -15,61 +25,107 @@ SYSTEM_PROMPT = (
      "If you don't know, say \"I don't know.\"\n"
  )

- API_URL = f"https://api-inference.huggingface.co/pipeline/text-generation/{MODEL_ID}"
- HEADERS = {"Authorization": f"Bearer {HF_TOKEN}"}
-
- def query_hf(prompt: str) -> str:
-     payload = {
-         "inputs": prompt,
-         "parameters": {
-             "max_new_tokens": 128,
-             "do_sample": True,
-             "temperature": 0.7,
-             "top_p": 0.9,
-             "return_full_text": False
-         },
-         "options": {"use_cache": False}
-     }
-     r = requests.post(API_URL, headers=HEADERS, json=payload, timeout=60)
-     r.raise_for_status()
-     data = r.json()
-     text = data[0]["generated_text"].strip()
-     for m in ("User:", "Assistant:"):
-         if m in text:
-             text = text.split(m)[0].strip()
-     return text or "I don't know."

  # ── Streamlit UI ──────────────────────────────────────────────────────────
  st.set_page_config(layout="centered")
- st.title("🧠 DeepTalks (Inference API)")
- st.subheader("Your personal AI Companion")

  if "history" not in st.session_state:
      st.session_state.history = []

- # Render chat history
- for role, msg in st.session_state.history:
-     st.chat_message("user" if role=="You" else "assistant").write(msg)

- user_input = st.chat_input("Type your message…")
  if user_input:
      st.chat_message("user").write(user_input)
      st.session_state.history.append(("You", user_input))

      recent = st.session_state.history[-CONTEXT_TURNS*2:]
      context = "\n".join(t for _, t in recent)
-     prompt = (
-         f"{SYSTEM_PROMPT}\n\n"
-         f"Context:\n{context}\n\n"
-         f"User: {user_input}\nAssistant:"
-     )

      with st.spinner("Thinking…"):
          try:
-             reply = query_hf(prompt)
          except Exception as e:
-             st.error(f"API error: {e}")
              reply = "I’m sorry, something went wrong."

      st.chat_message("assistant").write(reply)
      st.session_state.history.append(("Bot", reply))
 
  import os
  import streamlit as st
+ import torch
+ import logging
+ from transformers import (
+     pipeline,
+     AutoTokenizer,
+     AutoModelForCausalLM,
+     BitsAndBytesConfig,
+ )
+ from peft import PeftModel

  # ── Configuration ──────────────────────────────────────────────────────────
+ BASE_MODEL = "microsoft/phi-2"
+ ADAPTER_REPO = "sourize/phi2-memory-lora"
+ CONTEXT_TURNS = 7
+ MAX_NEW_TOKENS = 128
+ OFFLOAD_DIR = "offload"
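+ # Directory where accelerate spills weights that don't fit in memory.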

+ SYSTEM = (
      "You are a helpful assistant for DeepTalks with base Phi-2\n"
      "fine-tuned by Sourish for domain support.\n"
      "Answer **only** using the conversation context below.\n"
      "If you don't know, say \"I don't know.\"\n"
  )

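+ # cache_resource loads the model once per process and reuses it across Streamlit reruns.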
+ @st.cache_resource(show_spinner=False)
+ def load_pipeline():
+     # 1) Tokenizer
+     tokenizer = AutoTokenizer.from_pretrained(
+         BASE_MODEL, trust_remote_code=True, padding_side="left"
+     )
+     if tokenizer.pad_token_id is None:
+         tokenizer.add_special_tokens({"pad_token": "[PAD]"})
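+         # Phi-2 ships without a pad token; the embedding matrix is resized below to match.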
+
+     # 2) Base model: 4-bit NF4 on CUDA, plain FP32 on CPU
+     if torch.cuda.is_available():
+         quant_config = BitsAndBytesConfig(
+             load_in_4bit=True,
+             bnb_4bit_quant_type="nf4",
+             bnb_4bit_compute_dtype=torch.float16,
+         )
+         base = AutoModelForCausalLM.from_pretrained(
+             BASE_MODEL,
+             trust_remote_code=True,
+             quantization_config=quant_config,
+             device_map="auto",
+             low_cpu_mem_usage=True,  # keep peak RAM down while loading
+             offload_folder=OFFLOAD_DIR,
+             offload_state_dict=True,
+         )
+     else:
+         base = AutoModelForCausalLM.from_pretrained(
+             BASE_MODEL,
+             trust_remote_code=True,
+             torch_dtype=torch.float32,  # no CUDA in this branch, so full precision
+             device_map="cpu",  # force CPU
+         )
+
+     # 3) Resize embeddings for the added pad token, then apply the LoRA adapter
+     base.resize_token_embeddings(len(tokenizer))
+     model = PeftModel.from_pretrained(
+         base,
+         ADAPTER_REPO,
+         device_map="auto" if torch.cuda.is_available() else None,
+     )
+     model.eval()
+
+     # 4) Build the generation pipeline; the model already carries its device placement
+     gen = pipeline(
+         "text-generation",
+         model=model,
+         tokenizer=tokenizer,
+         max_new_tokens=MAX_NEW_TOKENS,
+         do_sample=True,
+         temperature=0.7,
+         top_p=0.9,
+         use_cache=True,
+         return_full_text=False,
+     )
+
+     logging.info("Pipeline loaded.")
+     return gen
+
+ generator = load_pipeline()

  # ── Streamlit UI ──────────────────────────────────────────────────────────
  st.set_page_config(layout="centered")
+ st.title("🧠 DeepTalks")
+ st.subheader("Your personal AI Companion", divider='grey')

  if "history" not in st.session_state:
      st.session_state.history = []

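+ # Replay prior turns; each history entry is a (role, text) pair.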
+ for role, text in st.session_state.history:
+     st.chat_message("user" if role == "You" else "assistant").write(text)

+ user_input = st.chat_input("Your message…")
  if user_input:
      st.chat_message("user").write(user_input)
      st.session_state.history.append(("You", user_input))

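+     # One turn adds a user and a bot entry, so slice the last CONTEXT_TURNS*2 items.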
      recent = st.session_state.history[-CONTEXT_TURNS*2:]
      context = "\n".join(t for _, t in recent)
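+     # The prompt stacks the system instructions, recent context, and the new user turn.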
+     prompt = f"""{SYSTEM}
+
+ Context:
+ {context}
+
+ User: {user_input}
+ Assistant:"""

      with st.spinner("Thinking…"):
          try:
+             reply = generator(prompt)[0]["generated_text"].strip()
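+             # Trim anything generated past the next speaker marker.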
+             for marker in ["User:", "Assistant:"]:
+                 if marker in reply:
+                     reply = reply.split(marker)[0].strip()
+             if not reply:
+                 reply = "I’m sorry, I didn’t catch that. Could you rephrase?"
          except Exception as e:
              reply = "I’m sorry, something went wrong."
+             st.error(f"Error: {e}")

      st.chat_message("assistant").write(reply)
      st.session_state.history.append(("Bot", reply))