JerameeUC committed on
Commit
a48e2a2
·
1 Parent(s): e63200a

Works; just fine-tuning.

Browse files
Files changed (4) hide show
  1. app_storefront.py +49 -12
  2. core/memory.py +25 -13
  3. core/model.py +33 -11
  4. core/storefront.py +128 -77
app_storefront.py CHANGED
@@ -10,6 +10,32 @@ sys.path.append(os.path.join(os.path.dirname(__file__), "core"))
10
  from core.model import model_generate, MODEL_NAME
11
  from core.memory import build_prompt_from_history
12
  from core.storefront import load_storefront, storefront_qna, extract_products, get_rules
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
13
 
14
  # ---------------- Load data + safe fallbacks ----------------
15
  DATA = load_storefront() # may be None if storefront_data.json missing/empty
@@ -42,19 +68,7 @@ else:
42
  VENUE_RULES = FALLBACK_VENUE
43
  PARKING_RULES = FALLBACK_PARKING
44
 
45
- def clean_generation(text: str) -> str:
46
- return (text or "").strip()
47
 
48
- # ---------------- Chat logic ----------------
49
- def chat_pipeline(history, message, max_new_tokens=128, temperature=0.8, top_p=0.95):
50
- # 1) Use storefront facts first (reduces hallucinations)
51
- sf = storefront_qna(DATA, message) # <-- pass DATA!
52
- if sf:
53
- return sf
54
- # 2) Memory-aware prompt to keep context grounded
55
- prompt = build_prompt_from_history(history, message, k=4)
56
- gen = model_generate(prompt, max_new_tokens, temperature, top_p)
57
- return clean_generation(gen)
58
 
59
  # ---------------- UI ----------------
60
  CSS = """
@@ -171,5 +185,28 @@ with gr.Blocks(title="Storefront Chat", css=CSS) as demo:
171
  health_btn.click(_health_cb, inputs=[history_state], outputs=[history_state, chat, status_md])
172
  caps_btn.click(_caps_cb, inputs=[history_state], outputs=[history_state, chat])
173
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
174
  if __name__ == "__main__":
175
  demo.launch(server_name="0.0.0.0", server_port=int(os.getenv("PORT", "7860")))
 
10
  from core.model import model_generate, MODEL_NAME
11
  from core.memory import build_prompt_from_history
12
  from core.storefront import load_storefront, storefront_qna, extract_products, get_rules
13
+ from core.storefront import is_storefront_query
14
+
15
def chat_pipeline(history, message, max_new_tokens=96, temperature=0.7, top_p=0.9):
    """Answer one chat message, preferring deterministic storefront facts.

    Order of attempts:
      1. ``storefront_qna`` on the loaded storefront data.
      2. A fixed guidance message when the text is not a storefront query
         (the language model is not called in this case).
      3. Language-model generation from a history-aware prompt, post-processed
         by ``clean_generation``.
    """
    deterministic_answer = storefront_qna(DATA, message)
    if deterministic_answer:
        return deterministic_answer

    if is_storefront_query(message):
        # Storefront-related but not covered by a canned answer: ask the model.
        llm_prompt = build_prompt_from_history(history, message, k=4)
        raw_text = model_generate(llm_prompt, max_new_tokens, temperature, top_p)
        return clean_generation(raw_text)

    # Off-topic input: steer the user toward supported questions.
    guidance = (
        "I can help with the graduation storefront. Examples:\n"
        "- Parking rules, lots opening times\n"
        "- Attire / dress code\n"
        "- Cap & Gown details and pickup\n"
        "- Parking passes (multiple allowed)\n"
        "Ask one of those, and I’ll answer directly."
    )
    return guidance
36
+
37
def clean_generation(text: str) -> str:
    """Return *text* without surrounding whitespace; falsy input yields ""."""
    # NOTE(review): a fuller clean_generation is defined again near the end of
    # this file; being later in the module, that definition rebinds this name.
    if not text:
        return ""
    return text.strip()
39
 
40
  # ---------------- Load data + safe fallbacks ----------------
41
  DATA = load_storefront() # may be None if storefront_data.json missing/empty
 
68
  VENUE_RULES = FALLBACK_VENUE
69
  PARKING_RULES = FALLBACK_PARKING
70
 
 
 
71
 
 
 
 
 
 
 
 
 
 
 
72
 
73
  # ---------------- UI ----------------
74
  CSS = """
 
185
  health_btn.click(_health_cb, inputs=[history_state], outputs=[history_state, chat, status_md])
186
  caps_btn.click(_caps_cb, inputs=[history_state], outputs=[history_state, chat])
187
 
188
def clean_generation(text: str) -> str:
    """Reduce raw model output to the assistant's reply only.

    Steps:
      1. Keep only what follows the last "Assistant:" label, if one is present.
      2. Cut at the first marker that signals a new turn or leaked prompt
         section.
      3. Drop lines that exactly repeat the line immediately before them.

    Args:
        text: Raw generated text; ``None`` or empty is treated as "".

    Returns:
        The cleaned reply, stripped of surrounding whitespace.
    """
    reply = (text or "").strip()

    # Keep only the text after the last "Assistant:" label.
    _, label, tail = reply.rpartition("Assistant:")
    if label:
        reply = tail.strip()

    # "\nSystem:" and "\n###" are included because StopOnMarkers in
    # core/model.py halts generation on them, which leaves the marker itself
    # as the last thing in the generated text.
    cut_marks = (
        "\nUser:",
        "\nSystem:",
        "\n###",
        "\nYOU ARE ANSWERING",
        "\nProducts:",
        "\nVenue rules:",
        "\nParking rules:",
    )
    hits = [pos for pos in map(reply.find, cut_marks) if pos != -1]
    if hits:
        reply = reply[:min(hits)].strip()

    # Collapse exact consecutive repeats only; non-adjacent duplicates stay.
    kept = []
    for line in reply.splitlines():
        if not kept or line != kept[-1]:
            kept.append(line)
    return "\n".join(kept).strip()
210
+
211
  if __name__ == "__main__":
212
  demo.launch(server_name="0.0.0.0", server_port=int(os.getenv("PORT", "7860")))
core/memory.py CHANGED
@@ -1,22 +1,34 @@
1
  # core/memory.py
 
 
 
 
 
 
 
 
2
  def build_prompt_from_history(history, user_text, k=4) -> str:
3
  """
4
- history is a list of [user, bot] pairs (Gradio Chatbot format).
5
- Keep a compact, factual system preface to ground the model.
6
  """
7
  lines = [
8
- "System: Answer questions about the university graduation storefront using the facts below.",
9
- "System: Be concise. If unsure, say what is known.",
10
- "Facts:",
11
- "- Cap & Gown Set (CG-SET): $59.00, tassel included; ships until 10 days before the event.",
12
- "- Parking Pass (PK-1): $10.00; multiple passes allowed per student.",
13
- "- Venue: formal attire recommended; no muscle shirts; no sagging pants.",
14
- "- Parking: no double parking; vehicles in handicap spaces will be towed.",
15
  ]
16
- for u, b in (history or [])[-k:]:
17
- if u: lines.append(f"User: {u}")
18
- if b: lines.append(f"Assistant: {b}")
 
 
 
 
 
 
 
 
 
 
19
  lines.append(f"User: {user_text}")
20
  lines.append("Assistant:")
21
  return "\n".join(lines)
22
-
 
1
  # core/memory.py
2
+
3
+ META_MARKERS = ("### Status:", "### Capabilities", "Status:", "Capabilities", "Model:", "Storefront JSON:")
4
+
5
+ def _is_meta(s: str | None) -> bool:
6
+ if not s: return False
7
+ ss = s.strip()
8
+ return any(m in ss for m in META_MARKERS)
9
+
10
  def build_prompt_from_history(history, user_text, k=4) -> str:
11
  """
12
+ history: list[[user, bot], ...] from Gradio Chatbot.
13
+ Keep prompt compact; exclude meta/diagnostic messages.
14
  """
15
  lines = [
16
+ "System: Answer questions about the university graduation storefront.",
17
+ "System: Be concise. If unsure, state what is known."
 
 
 
 
 
18
  ]
19
+
20
+ # Keep only the last k turns that aren't meta
21
+ kept = []
22
+ for u, b in (history or []):
23
+ if u and not _is_meta(u):
24
+ kept.append(("User", u))
25
+ if b and not _is_meta(b):
26
+ kept.append(("Assistant", b))
27
+ kept = kept[-(2*k):] # up to k exchanges
28
+
29
+ for role, text in kept:
30
+ lines.append(f"{role}: {text}")
31
+
32
  lines.append(f"User: {user_text}")
33
  lines.append("Assistant:")
34
  return "\n".join(lines)
 
core/model.py CHANGED
@@ -1,23 +1,45 @@
1
  # core/model.py
2
- import os
3
- from transformers import pipeline
4
 
5
  MODEL_NAME = os.getenv("HF_MODEL_GENERATION", "distilgpt2")
6
- _PIPE = None
7
 
8
- def get_pipe():
9
- global _PIPE
10
- if _PIPE is None:
11
- _PIPE = pipeline("text-generation", model=MODEL_NAME)
12
- return _PIPE
13
 
14
- def model_generate(prompt: str, max_new_tokens=128, temperature=0.8, top_p=0.95) -> str:
15
- out = get_pipe()(
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
16
  prompt,
17
  max_new_tokens=int(max_new_tokens),
18
  do_sample=True,
19
  temperature=float(temperature),
20
  top_p=float(top_p),
21
- pad_token_id=50256,
 
 
 
 
22
  )
23
  return out[0]["generated_text"]
 
1
  # core/model.py
2
+ import re, os
3
+ from transformers import pipeline, StoppingCriteria, StoppingCriteriaList
4
 
5
  MODEL_NAME = os.getenv("HF_MODEL_GENERATION", "distilgpt2")
6
+ _pipe = None
7
 
8
class StopOnMarkers(StoppingCriteria):
    """Stop generation as soon as the output ends with a turn/section marker.

    Two checks are applied to the first sequence in the batch:
      * the original token-level check: the trailing token ids equal the ids
        obtained by tokenizing a marker on its own;
      * a text-level check: the decoded tail ends with a marker string. This
        catches markers that were generated with different token boundaries
        than the standalone tokenization (for example a newline merged into a
        neighbouring token).
    """

    def __init__(self, tokenizer, stop_strs=("\nUser:", "\nSystem:", "\n###", "\nProducts:", "\nVenue rules:", "\nParking rules:")):
        self.tokenizer = tokenizer
        self.stop_strs = tuple(s for s in stop_strs if s)
        self.stop_ids = [tokenizer(s, add_special_tokens=False).input_ids for s in stop_strs]
        # Number of trailing tokens to decode for the text check. Assumes each
        # token contributes at least one character, so len(marker) tokens are
        # enough to cover the marker text -- TODO confirm for exotic tokenizers.
        self._window = max(
            [len(s) for s in self.stop_strs] + [len(seq) for seq in self.stop_ids],
            default=0,
        )

    def __call__(self, input_ids, scores, **kwargs):
        if not self._window:
            return False

        tail_ids = input_ids[0][-self._window:].tolist()

        # 1) Exact token-suffix match (original behaviour).
        for seq in self.stop_ids:
            size = len(seq)
            if size and len(tail_ids) >= size and tail_ids[-size:] == seq:
                return True

        # 2) Decoded-text suffix match. Only the END of the text is tested so
        #    that markers already present in the prompt do not trigger a stop.
        tail_text = self.tokenizer.decode(tail_ids, clean_up_tokenization_spaces=False)
        return tail_text.endswith(self.stop_strs)
20
+
21
def _get_pipe():
    """Return the module-wide text-generation pipeline, building it on first use."""
    global _pipe
    if _pipe is not None:
        return _pipe
    # First call: construct the pipeline for MODEL_NAME and cache it.
    _pipe = pipeline("text-generation", model=MODEL_NAME)
    return _pipe
26
+
27
def model_generate(prompt, max_new_tokens=96, temperature=0.7, top_p=0.9):
    """Sample a completion for *prompt* from the cached text-generation pipeline.

    Args:
        prompt: Full prompt text handed to the pipeline.
        max_new_tokens: Upper bound on generated tokens (coerced to int).
        temperature: Sampling temperature (coerced to float).
        top_p: Nucleus-sampling threshold (coerced to float).

    Returns:
        The ``generated_text`` field of the first pipeline result.
        NOTE(review): with the pipeline's default settings this presumably
        still contains the prompt; callers strip it afterwards -- confirm.
    """
    pipe = _get_pipe()
    tok = pipe.tokenizer

    # Token id 0 is a valid id, so test for None explicitly instead of using
    # `or`, which would replace a legitimate 0 with the fallback.
    eos_id = tok.eos_token_id
    pad_id = 50256 if eos_id is None else eos_id  # 50256: previous hard-coded fallback

    stop = StoppingCriteriaList([StopOnMarkers(tok)])

    out = pipe(
        prompt,
        max_new_tokens=int(max_new_tokens),
        do_sample=True,
        temperature=float(temperature),
        top_p=float(top_p),
        repetition_penalty=1.15,   # discourages exact loops
        no_repeat_ngram_size=3,    # blocks short repeats like "Account/Account"
        pad_token_id=pad_id,
        eos_token_id=eos_id,       # stop at EOS if the model defines one
        stopping_criteria=stop,
    )
    return out[0]["generated_text"]
core/storefront.py CHANGED
@@ -1,5 +1,133 @@
 
1
  import json, os
2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3
  def _find_json():
4
  candidates = [
5
  os.path.join(os.getcwd(), "storefront_data.json"),
@@ -45,80 +173,3 @@ def extract_products(data):
45
  def get_rules(data):
46
  pol = (data or {}).get("policies", {}) or {}
47
  return pol.get("venue_rules", []), pol.get("parking_rules", [])
48
-
49
- def storefront_qna(data, user_text: str):
50
- """
51
- Lightweight rules: try exact single-word intents first, then faq,
52
- then rules/products lookup. Return None to allow LLM fallback.
53
- """
54
- if not user_text:
55
- return None
56
- t = user_text.strip().lower()
57
-
58
- # single-word catches
59
- if t in {"parking"}:
60
- _, pr = get_rules(data)
61
- if pr: return "Parking rules:\n- " + "\n- ".join(pr)
62
-
63
- # NEW: map 'wear' directly to venue rules to avoid LLM hallucinations
64
- if t in {"venue", "attire", "dress", "dress code", "wear"} or "what should i wear" in t:
65
- vr, _ = get_rules(data)
66
- if vr: return "Venue rules:\n- " + "\n- ".join(vr)
67
-
68
- if t in {"passes", "parking pass", "parking passes"}:
69
- return "Yes, multiple parking passes are allowed per student."
70
-
71
- # faq
72
- a = answer_faq(data, t)
73
- if a: return a
74
-
75
- # explicit rule asks
76
- if "parking" in t and "rule" in t:
77
- _, pr = get_rules(data)
78
- if pr: return "Parking rules:\n- " + "\n- ".join(pr)
79
- if ("venue" in t and "rule" in t) or "attire" in t or "dress code" in t:
80
- vr, _ = get_rules(data)
81
- if vr: return "Venue rules:\n- " + "\n- ".join(vr)
82
-
83
- # “lots open” style
84
- if "parking" in t and ("hours" in t or "time" in t or "open" in t):
85
- lots_open = (((data or {}).get("logistics") or {}).get("lots_open_hours_before") or 2)
86
- return f"Parking lots open {lots_open} hours before the ceremony."
87
-
88
- # products
89
- if "cap" in t or "gown" in t or "parking pass" in t or "product" in t:
90
- prods = extract_products(data)
91
- if prods:
92
- lines = []
93
- for p in prods:
94
- price = p["price"]
95
- price_str = f"${price:.2f}" if isinstance(price, (int,float)) else str(price)
96
- lines.append(f"{p['name']} — {price_str}: {p['notes']}")
97
- return "\n".join(lines)
98
-
99
- return None
100
-
101
- # app_storefront.py
102
-
103
- def clean_generation(text: str) -> str:
104
- s = (text or "").strip()
105
-
106
- # If the prompt contained "Assistant:", keep only what comes after the last one
107
- last = s.rfind("Assistant:")
108
- if last != -1:
109
- s = s[last + len("Assistant:"):].strip()
110
-
111
- # If it accidentally continued into a new "User:" or instructions, cut there
112
- cut_marks = ["\nUser:", "\nYOU ARE ANSWERING", "\nProducts:", "\nVenue rules:", "\nParking rules:"]
113
- cut_positions = [s.find(m) for m in cut_marks if s.find(m) != -1]
114
- if cut_positions:
115
- s = s[:min(cut_positions)].strip()
116
-
117
- # Collapse repeated lines like "Yes, multiple parking passes..." spam
118
- lines, out = s.splitlines(), []
119
- seen = set()
120
- for ln in lines:
121
- # dedupe only exact consecutive repeats; keep normal conversation lines
122
- if not out or ln != out[-1]:
123
- out.append(ln)
124
- return "\n".join(out).strip()
 
1
+ # core/storefront.py
2
  import json, os
3
 
4
def clean_generation(text: str) -> str:
    """Reduce raw model output to the assistant's reply only.

    Steps:
      1. Keep only the text after the last "Assistant:" label.
      2. Cut at the first marker of a new turn or meta section.
      3. Collapse slash-separated token loops such as "Account/Account/...".
      4. Drop lines that repeat the previous line (ignoring edge whitespace).

    Args:
        text: Raw generated text; ``None`` or empty is treated as "".

    Returns:
        The cleaned reply, stripped of surrounding whitespace.
    """
    # Imported here because this module's top-level imports are only json and
    # os; without it the re.sub call below raises NameError.
    import re

    s = (text or "").strip()

    # Keep only text after the last "Assistant:"
    last = s.rfind("Assistant:")
    if last != -1:
        s = s[last + len("Assistant:"):].strip()

    # Cut at the first sign of a new turn or meta
    cut_marks = ("\nUser:", "\nSystem:", "\n###", "\nProducts:", "\nVenue rules:", "\nParking rules:")
    cuts = [pos for pos in map(s.find, cut_marks) if pos != -1]
    if cuts:
        s = s[:min(cuts)].strip()

    # Remove egregious token loops like "Account/Account/...".
    # The token class itself admits "/", so one pass can match a doubled unit
    # ("A/A" repeated twice) and leave "A/A" behind; repeat until nothing
    # changes. Every successful substitution shortens the string, so this ends.
    loop_pattern = r"(?:\b([A-Z][a-zA-Z0-9_/.-]{2,})\b(?:\s*/\s*\1\b)+)"
    while True:
        collapsed = re.sub(loop_pattern, r"\1", s)
        if collapsed == s:
            break
        s = collapsed

    # Collapse consecutive duplicate lines
    dedup = []
    for ln in s.splitlines():
        if not dedup or ln.strip() != dedup[-1].strip():
            dedup.append(ln)
    return "\n".join(dedup).strip()
27
+
28
# Phrases that signal a request for help / capabilities.
HELP_KEYWORDS = {
    "help", "assist", "assistance", "tips", "how do i", "what can you do",
    "graduation help", "help me with graduation", "can you help me with graduation"
}

# Words that signal a storefront-related topic.
STORE_KEYWORDS = {
    "cap", "gown", "parking", "pass", "passes", "attire", "dress",
    "venue", "logistics", "shipping", "pickup", "lot", "lots", "arrival", "size", "sizing"
}


def is_storefront_query(text: str) -> bool:
    """Return True when *text* mentions a storefront or help keyword.

    Matching is done on whole words (with an optional plural "s"/"es") rather
    than raw substrings, so "password" no longer counts as "pass", "pilots" as
    "lot", or "escape" as "cap". Multi-word phrases must appear as consecutive
    words.

    Args:
        text: User message; ``None`` or empty returns False.
    """
    # Lower-case, turn every non-alphanumeric character into a space, and
    # re-join with single spaces so each word is delimited by exactly one space.
    lowered = (text or "").lower()
    words = "".join(ch if ch.isalnum() else " " for ch in lowered).split()
    if not words:
        return False
    padded = " " + " ".join(words) + " "

    # Keywords are assumed to consist of letters and single spaces only.
    return any(
        f" {kw} " in padded or f" {kw}s " in padded or f" {kw}es " in padded
        for kw in (*STORE_KEYWORDS, *HELP_KEYWORDS)
    )
41
+
42
+ def _get_lots_open_hours(data) -> int:
43
+ try:
44
+ return int(((data or {}).get("logistics") or {}).get("lots_open_hours_before") or 2)
45
+ except Exception:
46
+ return 2
47
+
48
+ # Main router (drop-in)
49
def storefront_qna(data, user_text: str) -> str | None:
    """Return a deterministic storefront answer, or None if nothing matches.

    Checks, in order:
      1. exact short intents (parking / venue-attire / passes)
      2. help / capability requests
      3. the JSON-driven FAQ, when an ``answer_faq`` function exists
      4. explicit rules questions
      5. "when do lots open" style timing questions
      6. a compact product list

    Args:
        data: Parsed storefront data (may be None).
        user_text: Raw user message.

    Returns:
        The answer text, or None so the caller can fall back to the LLM.
    """
    if not user_text:
        return None
    t = user_text.strip().lower()

    def _bulleted(title, rules):
        # Render a rule list as "<title>\n- a\n- b".
        return title + "\n- " + "\n- ".join(rules)

    # 1) Single-word / exact intents to avoid LLM hallucinations
    if t in {"parking"}:
        _, pr = get_rules(data)
        if pr:
            return _bulleted("Parking rules:", pr)

    # Map 'wear/attire' variants directly to venue rules
    if t in {"venue", "attire", "dress", "dress code", "wear"} or "what should i wear" in t:
        vr, _ = get_rules(data)
        if vr:
            return _bulleted("Venue rules:", vr)

    # Parking passes (multiple allowed)
    if t in {"passes", "parking pass", "parking passes"}:
        return "Yes, multiple parking passes are allowed per student."

    # 2) Help / capability intent -> deterministic guidance
    if any(k in t for k in HELP_KEYWORDS):
        return (
            "I can help with the graduation storefront. Try:\n"
            "- “What are the parking rules?”\n"
            "- “Can I buy multiple parking passes?”\n"
            "- “Is formal attire required?”\n"
            "- “Where do I pick up the gown?”\n"
            "- “When do lots open?”"
        )

    # 3) JSON-driven FAQ (if available). The helper is looked up explicitly so
    #    its absence is not confused with a failure inside it, and only
    #    data-shape errors are tolerated instead of swallowing every exception.
    faq_lookup = globals().get("answer_faq")
    if callable(faq_lookup):
        try:
            faq_answer = faq_lookup(data, t)
        except (AttributeError, LookupError, TypeError, ValueError):
            faq_answer = None  # missing or malformed FAQ data: best effort
        if faq_answer:
            return faq_answer

    # 4) Explicit rules phrasing (keeps answers tight and consistent)
    if "parking" in t and "rule" in t:
        _, pr = get_rules(data)
        if pr:
            return _bulleted("Parking rules:", pr)

    if ("venue" in t and "rule" in t) or "attire" in t or "dress code" in t:
        vr, _ = get_rules(data)
        if vr:
            return _bulleted("Venue rules:", vr)

    # 5) "When do lots open?" / hours / time.
    #    "lots" is accepted as well as "parking": the help text above suggests
    #    the question "When do lots open?", which never mentions "parking".
    if ("parking" in t or "lots" in t) and any(k in t for k in ("hours", "time", "open")):
        lots_open = _get_lots_open_hours(data)
        return f"Parking lots open {lots_open} hours before the ceremony."

    # 6) Product info (cap/gown/parking pass)
    if any(k in t for k in ("cap", "gown", "parking pass", "product", "item", "price")):
        prods = extract_products(data)
        if prods:
            lines = []
            for p in prods:
                name = p.get("name", "Item")
                price = p.get("price", p.get("price_usd", ""))
                notes = p.get("notes", p.get("description", ""))
                price_str = f"${price:.2f}" if isinstance(price, (int, float)) else str(price)
                lines.append(f"{name} — {price_str}: {notes}")
            return "\n".join(lines)

    # No deterministic match -> let the caller fall back to the LLM
    return None
130
+
131
  def _find_json():
132
  candidates = [
133
  os.path.join(os.getcwd(), "storefront_data.json"),
 
173
  def get_rules(data):
174
  pol = (data or {}).get("policies", {}) or {}
175
  return pol.get("venue_rules", []), pol.get("parking_rules", [])