Spaces:
Runtime error
Runtime error
| """ | |
| Boston School Finder — Chatbot core | |
| ==================================== | |
| Designed for Llama 3.1 8B Instruct. | |
| Instead of native tool calling (which 8B models handle unreliably), | |
| the model outputs a simple tag like: | |
| [TOOL: find_eligible_schools | grade_level=K2 | street_address=123 Main St | zip_code=02118] | |
| Our code detects the tag, parses it, executes the query, and feeds | |
| the results back for the model to summarize. | |
| Two tools: | |
| 1. find_eligible_schools — calls Avela API for eligible school IDs | |
| 2. filter_based_on_preferences — filters/ranks eligible schools by user preferences | |
| """ | |
| import json | |
| import re | |
| from huggingface_hub import InferenceClient | |
| from config import BASE_MODEL, MY_MODEL, HF_TOKEN, SYSTEM_PROMPT, MAX_ELIGIBLE_SCHOOLS_RETURNED, SYSTEM_PROMPT_FIND, SYSTEM_PROMPT_REG, SYSTEM_PROMPT_CONTACT | |
| from data.database import BPSDatabase | |
| from data.check_eligibility_tool import find_eligible_schools | |
# ────────────────────────────────────────────────────────────────
# CONSTANTS
# ────────────────────────────────────────────────────────────────
MAX_TOOL_ROUNDS = 4          # max tool-call loops per user message
MAX_TOOL_RESULT_ITEMS = 15   # truncate large result lists
MAX_CLEAN_RETRIES = 2        # re-prompt attempts if output still has tags
# ────────────────────────────────────────────────────────────────
# CHATBOT CLASS
# ────────────────────────────────────────────────────────────────
class Chatbot:
    """
    Tag-based tool-calling chatbot designed for small (8B) models.

    The model outputs [TOOL: fn | arg=val] tags. Our code parses them,
    runs the query, and feeds the results back to the model as a
    follow-up "user" message. No native tool calling is used.
    """

    # All recognized tool names
    TOOL_NAMES = {
        "find_eligible_schools", "filter_based_on_preferences",
    }

    # Regex to match [TOOL: function_name | arg=val | arg=val]
    TOOL_TAG_RE = re.compile(
        r'\[TOOL:\s*(\w+)'                # function name
        r'((?:\s*\|\s*\w+=?[^|\]]*)*)'    # optional | arg=val pairs
        r'\s*\]',
        re.IGNORECASE
    )

    def __init__(self):
        """Create the inference client, the database handle and an empty eligibility cache."""
        model_id = MY_MODEL if MY_MODEL else BASE_MODEL
        self.client = InferenceClient(model=model_id, token=HF_TOKEN)
        self.db = BPSDatabase()
        # Populated by find_eligible_schools, used by filter_based_on_preferences.
        self._eligible_ids = None
        # Arguments that produced the cached result; lets us detect a changed query.
        self._eligible_args = None
        self._eligible_schools = []
        self._eligible_provider_type_counts = dict()

    # ── Parse [TOOL: ...] tags ────────────────────────────────

    @classmethod
    def _parse_tool_tag(cls, text):
        """
        Parse the first [TOOL: fn_name | arg=val | ...] tag found in text.

        Returns (fn_name, args_dict), or None when there is no tag or the
        tag names a tool that is not in TOOL_NAMES. Argument values are
        always kept as strings; empty values and "null"/"none" are dropped.
        """
        if not text:
            return None
        match = cls.TOOL_TAG_RE.search(text)
        if not match:
            return None

        fn_name = match.group(1).strip().lower()
        if fn_name not in cls.TOOL_NAMES:
            return None

        args = {}
        for segment in match.group(2).split("|"):
            segment = segment.strip()
            if not segment or "=" not in segment:
                continue
            key, _, val = segment.partition("=")
            key = key.strip()
            val = val.strip()
            if not val or val.lower() in ("null", "none"):
                continue
            # No numeric conversion: find_eligible_schools requires its
            # arguments to be strings (e.g. zip codes with leading zeros).
            args[key] = val
        return (fn_name, args)

    @classmethod
    def _has_tool_tag(cls, text):
        """Check if text contains a [TOOL: ...] tag."""
        return cls.TOOL_TAG_RE.search(text or "") is not None

    @classmethod
    def _strip_tool_tags(cls, text):
        """Remove [TOOL: ...] tags from text, keeping surrounding prose."""
        if not text:
            return ""
        cleaned = cls.TOOL_TAG_RE.sub("", text)
        cleaned = re.sub(r'\n{3,}', '\n\n', cleaned)
        return cleaned.strip()

    @classmethod
    def _contains_artifacts(cls, text):
        """
        Check if text contains anything the user shouldn't see:
        tool tags, raw JSON blobs with tool names, function-call syntax,
        or a stray <T> marker.
        """
        if not text:
            return False
        if cls._has_tool_tag(text):
            return True
        tool_names_pat = "|".join(re.escape(name) for name in sorted(cls.TOOL_NAMES))
        patterns = [
            rf'\{{\s*"(?:name|function)"\s*:\s*"(?:{tool_names_pat})"',
            rf'"(?:arguments|parameters)"\s*:\s*\{{',
            r"<T>",
        ]
        return any(re.search(p, text, re.IGNORECASE) for p in patterns)

    def _clean_eligible_schools(self, eligible_schools):
        """Return copies of the school dicts with None and empty-string values removed."""
        return [
            {key: value for key, value in school.items()
             if value is not None and value != ""}
            for school in eligible_schools
        ]

    # ── Tool execution ────────────────────────────────────────

    def _execute_tool(self, fn_name, args):
        """
        Dispatch a tool call. Returns a JSON string.

        Any exception raised while running the tool is reported to the
        model as {"error": "..."} instead of propagating.
        """
        try:
            if fn_name == "find_eligible_schools":
                print("[TOOL CALL] find_eligible_schools")
                # Re-query when nothing is cached yet, or when the model
                # supplies arguments that differ from the ones behind the
                # cached result (e.g. the user corrected their address).
                # An argument-less repeat call keeps reusing the cache.
                needs_query = self._eligible_ids is None or (
                    bool(args) and args != self._eligible_args
                )
                if needs_query:
                    print("Executing find_eligible_schools")
                    result = find_eligible_schools(**args)
                    if result.get("error"):
                        return json.dumps({"error": result["error"]})
                    schools = result.get("eligible_schools", [])
                    # Store eligible IDs on the instance for filter_based_on_preferences
                    self._eligible_ids = [str(s["id"]) for s in schools]
                    self._eligible_schools = schools
                    self._eligible_provider_type_counts = result.get(
                        "eligible_provider_type_counts", {}
                    )
                    self._eligible_args = dict(args)
                else:
                    print("already stored eligible schools")

                # Only a small sample of schools is sent back to the model.
                sample_eligible_schools = self._clean_eligible_schools(
                    self._eligible_schools[:MAX_ELIGIBLE_SCHOOLS_RETURNED]
                )
                return json.dumps({
                    "eligible_count": len(self._eligible_ids),
                    "sample_eligible_schools": sample_eligible_schools,
                    "eligible_provider_type_counts": self._eligible_provider_type_counts,
                })

            elif fn_name == "filter_based_on_preferences":
                print("[TOOL CALL] filter_based_on_preferences")
                if not self._eligible_ids:
                    return json.dumps({
                        "error": "No eligible schools found yet. Call find_eligible_schools first."
                    })
                result = self.db.filter_based_on_preferences(
                    self._eligible_ids, **args
                )
                return json.dumps(result, default=str)

            else:
                print("[TOOL CALL] unknown tool")
                return json.dumps({"error": f"Unknown tool: {fn_name}"})
        except Exception as e:
            # Boundary handler: the error text goes back to the model so it
            # can recover or ask the user for different details.
            return json.dumps({"error": str(e)})

    # ── Message building ──────────────────────────────────────

    def _build_messages(self, user_input, history=None, mode="Find School"):
        """
        Build messages list from Gradio history + current input.
        Handles both Gradio 3.x (pair lists) and 4.x (dict lists).
        The system prompt is selected by mode; unknown modes use the
        "Find School" prompt.
        """
        prompt_map = {
            "Find School": SYSTEM_PROMPT_FIND,
            "Registration Guide": SYSTEM_PROMPT_REG,
            "Contact Info": SYSTEM_PROMPT_CONTACT,
        }
        # Get the right prompt, defaulting to FIND if something goes wrong
        active_system_prompt = prompt_map.get(mode, SYSTEM_PROMPT_FIND)
        messages = [{"role": "system", "content": active_system_prompt}]

        if history:
            if isinstance(history[0], dict):
                for msg in history:
                    role = msg.get("role", "")
                    content = msg.get("content", "")
                    if role in ("user", "assistant") and content:
                        messages.append({"role": role, "content": content})
            else:
                for user_msg, assistant_msg in history:
                    if user_msg:
                        messages.append({"role": "user", "content": user_msg})
                    if assistant_msg:
                        messages.append({"role": "assistant", "content": assistant_msg})

        messages.append({"role": "user", "content": user_input})
        return messages

    @staticmethod
    def _tool_result_prompt(fn_name, tool_output):
        """Return the follow-up prompt that hands tool_output back to the model, or None."""
        if fn_name == "find_eligible_schools":
            return (
                f"System Observation: Data from database: \n{tool_output}\n\n"
                "Using ONLY the data above, state the number of eligible schools, "
                "briefly categorize and summarize the sample school options visible, "
                "and ask for user preferences on how to narrow down the options. "
                "Do NOT include any [TOOL:] tags, JSON, or code."
            )
        if fn_name == "filter_based_on_preferences":
            return (
                f"System Observation: Data from database: \n{tool_output}\n\n"
                "Now respond to my previous message with a helpful response using ONLY "
                "the data above. Explain the filtered school options, the main trade-offs, "
                "and how I could narrow further if needed. "
                "Do NOT include any [TOOL:] tags, JSON, or code."
            )
        return None

    # ── Main response loop ────────────────────────────────────

    def get_response(self, user_input, history=None, mode="Find School"):
        """
        Generate a response to the user's message.

        Loop:
        1. Call model (no native tools — plain text generation).
        2. If output contains [TOOL: ...], parse + execute + inject
           results, then loop.
        3. If output is plain text, sanitize and return.

        After MAX_TOOL_ROUNDS tool rounds the model is told to answer
        without tools; that answer is sanitized as well.
        """
        messages = self._build_messages(user_input, history, mode)

        for _round in range(MAX_TOOL_ROUNDS):
            response = self.client.chat_completion(
                messages=messages,
                max_tokens=2048,
            )
            content = response.choices[0].message.content or ""

            # ── Check for [TOOL: ...] tag ────────────────────
            parsed = self._parse_tool_tag(content)
            if parsed:
                fn_name, fn_args = parsed
                tool_output = self._execute_tool(fn_name, fn_args)
                # Record the assistant's tool-tag turn
                messages.append({"role": "assistant", "content": content})
                # Feed results back as a user turn carrying the observation
                followup = self._tool_result_prompt(fn_name, tool_output)
                if followup is not None:
                    messages.append({"role": "user", "content": followup})
                continue  # loop for the model's final answer

            # ── No tool tag → candidate final answer ─────────
            clean = self._clean_output(content)
            print(messages)
            return self._sanitize(clean, messages, user_input)

        # ── Exhausted rounds → force a plain reply ───────────
        messages.append({
            "role": "system",
            "content": (
                "You must respond to the user now. Do not use any [TOOL:] tags. "
                "Answer using whatever information you have, or tell the user "
                "you need more details."
            ),
        })
        response = self.client.chat_completion(
            messages=messages,
            max_tokens=2048,
        )
        final = self._clean_output(response.choices[0].message.content or "")
        # The model may still emit a tag despite the instruction, so the
        # forced reply goes through the same sanitizer as a normal answer.
        return self._sanitize(final, messages, user_input)

    # ── Output cleaning ───────────────────────────────────────

    @staticmethod
    def _clean_output(text):
        """Basic cleanup: strip the reasoning block, a leading <R>/<T> marker, whitespace."""
        if not text:
            return ""
        # Strip the reasoning block (everything between <think> and </think>)
        text = re.sub(r'<think>.*?</think>', '', text, flags=re.DOTALL)
        text = re.sub(r"^\s*<[RT]>\s*", "", text, count=1)
        return text.strip()

    def _sanitize(self, text, messages, user_input):
        """
        If final response still contains tool artifacts,
        re-prompt the model to rewrite cleanly (up to MAX_CLEAN_RETRIES
        times; the re-prompt turns are appended to messages).
        Falls back to stripping artifacts mechanically, and to a fixed
        apology message when nothing displayable remains.
        """
        for _ in range(MAX_CLEAN_RETRIES):
            if not self._contains_artifacts(text):
                break
            messages.append({"role": "assistant", "content": text})
            messages.append({
                "role": "system",
                "content": (
                    "[REWRITE NEEDED] Your response contained [TOOL:] tags or "
                    "raw data the user cannot see. Rewrite as a plain friendly "
                    "message with no tags, no JSON, no code.\n"
                    f"The user asked: \"{user_input}\""
                ),
            })
            retry = self.client.chat_completion(
                messages=messages,
                max_tokens=2048,
            )
            text = self._clean_output(retry.choices[0].message.content or "")

        # Last resort: mechanically strip any remaining artifacts
        if self._contains_artifacts(text):
            text = self._strip_tool_tags(text)
            text = re.sub(r'\{[^{}]*(?:\{[^{}]*\}[^{}]*)*\}', '', text)
            text = re.sub(r'<T>', '', text, flags=re.IGNORECASE)
            text = re.sub(r'\n{3,}', '\n\n', text).strip()

        # An empty reply (from the model or after stripping) is never shown as-is.
        return text if text else (
            "I'm sorry, I ran into a technical issue. Please try again, "
            "or contact a BPS Welcome Center at 617-635-9010 for help."
        )