Spaces:
Running on Zero
Running on Zero
Codex
Refactor UI examples, add multi-turn and deflection dataset examples, and implement context checks
fe43097 | from __future__ import annotations | |
| import os | |
| import re | |
| from typing import Any | |
| import torch | |
| from peft import PeftModel | |
| from transformers import AutoModelForCausalLM, AutoTokenizer | |
| from env.config import ADAPTER_REPO_ID, MODEL_ID | |
# Process-wide cache so the model is loaded once and stays warm after the
# first request. All three names are rebound by get_model_and_tokenizer().
_model: Any = None  # cached causal LM (adapter-wrapped when the adapter loaded)
_tokenizer: Any = None  # tokenizer loaded from the same MODEL_ID as the model
_adapter_applied = False  # True once the LoRA adapter was applied successfully
# Reasoning-capable models sometimes emit hidden thinking tags.
_THINK_BLOCK_PATTERN = re.compile(r"<think>.*?</think>", re.IGNORECASE | re.DOTALL)
_THINK_START_PATTERN = re.compile(r"<think>.*", re.IGNORECASE | re.DOTALL)
_THINK_END_PATTERN = re.compile(r"</think>", re.IGNORECASE)

# Header that marks the start of the structured reflection.
_EMOTIONS_HEADER = "=== EMOTIONS ==="

# Anything after one of these markers is an accidental chat-template continuation.
_CONTINUATION_MARKERS = ("<|im_end|>", "<|im_start|>", "\nUser:", "\nAssistant:")

# Whole lines (lower-cased, trailing "." / ":" removed) that only announce a reply.
_META_INTRODUCTIONS = frozenset(
    {
        "here goes",
        "here is a reply",
        "here is the reply",
        "sure",
        "here's a reply",
        "here's the reply",
        "here is my reply",
        "here is a cbt reflection",
        "here's a cbt reflection",
        "here is the reflection",
        "cbt reflection",
    }
)

# Speaker labels such as "AI:", "Assistant:", "Coach:", "InnerSpace:", "Response:".
_SPEAKER_PREFIX_PATTERN = re.compile(
    r"^(ai|assistant|coach|innerspace|response):\s*", re.IGNORECASE
)

# A sentence ends at a run of terminal punctuation (optionally followed by a
# closing quote/bracket) that is followed by whitespace or the end of the text.
# Requiring the trailing whitespace keeps decimals such as "3.5" intact, and
# matching the whole run keeps "..." and "?!" together.
_SENTENCE_END_PATTERN = re.compile(r"[.!?]+[\"')\]]*(?=\s|$)")


def _strip_meta_lines(text: str) -> str:
    """Drops blank/meta-introduction lines and leading speaker labels."""
    kept_lines = []
    for raw_line in text.splitlines():
        line = raw_line.strip()
        if not line or line.lower().rstrip(".:") in _META_INTRODUCTIONS:
            continue
        line = _SPEAKER_PREFIX_PATTERN.sub("", line)
        if line:
            kept_lines.append(line)
    # Surviving lines are separated by a blank line, as paragraphs.
    return "\n\n".join(kept_lines)


def _truncate_to_sentences(text: str, max_sentences: int) -> str:
    """Returns the leading ``max_sentences`` sentences of ``text`` verbatim."""
    if max_sentences <= 0:
        # Nothing sensible to keep; fall back to the full text.
        return text
    boundaries = _SENTENCE_END_PATTERN.finditer(text)
    for count, boundary in enumerate(boundaries, start=1):
        if count == max_sentences:
            # Slice the original text so no characters inside it are lost.
            return text[: boundary.end()].strip()
    # Fewer sentences than the limit: keep everything.
    return text


def clean_generated_text(text: str, max_sentences: int | None = None) -> str:
    """Removes hidden reasoning and trims rambling generated text.

    Args:
        text: Raw decoded model output.
        max_sentences: When ``None`` the text is only cleaned. When set, the
            text is treated as a chat reply: meta-introduction lines and
            speaker labels are stripped and the reply is cut after this many
            sentences. Values ``<= 0`` disable the cut.

    Returns:
        The cleaned text, which may be empty.
    """
    # Prefer visible content after a completed hidden-reasoning block.
    text = _THINK_END_PATTERN.split(text, maxsplit=1)[-1]

    # Remove complete or dangling hidden-reasoning blocks.
    text = _THINK_BLOCK_PATTERN.sub("", text)
    text = _THINK_START_PATTERN.sub("", text)

    # Keep only the last structured section when the header is repeated.
    if _EMOTIONS_HEADER in text:
        text = _EMOTIONS_HEADER + text.rsplit(_EMOTIONS_HEADER, 1)[1]

    # Drop accidental chat-template continuations.
    for marker in _CONTINUATION_MARKERS:
        text = text.split(marker, 1)[0]

    # Normalize horizontal whitespace, but preserve newlines.
    text = re.sub(r"[ \t]+", " ", text)
    text = "\n".join(line.strip() for line in text.splitlines())
    text = re.sub(r"\n{3,}", "\n\n", text).strip()

    if max_sentences is None:
        return text

    # Conversational meta-prefix stripping applies to chat replies only.
    text = _strip_meta_lines(text)
    if not text:
        return text

    # Keep coach replies compact for the visible chat panel.
    return _truncate_to_sentences(text, max_sentences)
def _format_chat_prompt(tokenizer: Any, messages: list[dict[str, str]]) -> str:
    """Renders ``messages`` with the tokenizer's chat template as plain text.

    The first attempt passes ``enable_thinking=False``; if the tokenizer
    rejects that keyword with a ``TypeError``, the template is applied again
    without it.

    Args:
        tokenizer: Object exposing ``apply_chat_template``.
        messages: Chat messages passed through to the template unchanged.

    Returns:
        Whatever ``apply_chat_template`` returns for ``tokenize=False`` with a
        generation prompt appended.
    """
    template_options = {"tokenize": False, "add_generation_prompt": True}
    try:
        rendered = tokenizer.apply_chat_template(
            messages, **template_options, enable_thinking=False
        )
    except TypeError:
        # The keyword was not accepted: retry with the common options only.
        rendered = tokenizer.apply_chat_template(messages, **template_options)
    return rendered
def get_model_and_tokenizer(log_lines: list[str]) -> tuple[Any, Any]:
    """Loads and caches the Hugging Face model and tokenizer lazily.

    The first successful call loads the tokenizer and base model for
    ``MODEL_ID``, tries to apply the LoRA adapter from ``ADAPTER_REPO_ID``
    (falling back to the base model if that fails), and moves the model to
    CUDA when available. Later calls return the cached objects.

    The module-level cache is written only after every step has succeeded, so
    a failure part-way through (for example while moving to CUDA) does not
    leave a half-initialized model cached for subsequent requests.

    Args:
        log_lines: Mutable list that receives human-readable progress lines.

    Returns:
        A ``(model, tokenizer)`` tuple.

    Raises:
        Exception: Whatever the tokenizer/model loading or the device move
            raises; only adapter-loading errors are handled here.
    """
    global _adapter_applied, _model, _tokenizer

    if _model is not None:
        # Make cached-request diagnostics explicit without reloading weights.
        adapter_status = "with LoRA adapter" if _adapter_applied else "without adapter"
        log_lines.append(f"Using cached model {adapter_status}.")
        return _model, _tokenizer

    hf_token = os.environ.get("HF_TOKEN")

    # Load tokenizer before model so prompt formatting is ready.
    log_lines.append(f"Loading tokenizer: {MODEL_ID}")
    tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, token=hf_token)

    # Use bfloat16 on CUDA and float32 elsewhere for compatibility.
    log_lines.append(f"Loading model: {MODEL_ID}")
    dtype = torch.bfloat16 if torch.cuda.is_available() else torch.float32
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_ID,
        dtype=dtype,
        low_cpu_mem_usage=True,
        token=hf_token,
    )

    # Apply the CBT adapter when it is available.
    log_lines.append(f"Loading LoRA adapter: {ADAPTER_REPO_ID}")
    try:
        model = PeftModel.from_pretrained(model, ADAPTER_REPO_ID, token=hf_token)
    except Exception as adapter_error:
        # Deliberate best-effort: the base model is still usable on its own.
        adapter_applied = False
        log_lines.append(
            f"Warning: Could not load adapter ({adapter_error}). Using base model."
        )
    else:
        adapter_applied = True
        log_lines.append("LoRA adapter applied.")

    # Move the loaded model to GPU memory when ZeroGPU is active.
    if torch.cuda.is_available():
        log_lines.append("Moving model to CUDA.")
        model = model.to("cuda")

    # Publish to the cache only now that every step has succeeded.
    _model, _tokenizer, _adapter_applied = model, tokenizer, adapter_applied
    return _model, _tokenizer
def run_model_inference(prompt: str) -> tuple[str, str]:
    """Generates a reflection for ``prompt`` with the locally cached model.

    Args:
        prompt: User text sent as a single ``user`` chat message.

    Returns:
        ``(response, log)`` where ``log`` is the newline-joined progress
        lines. On any exception ``response`` is the empty string and the log
        records the failure.
    """
    log_lines: list[str] = []
    try:
        # Reuse the cached local model for the journal analysis.
        log_lines.append("Initializing local model execution...")
        model, tokenizer = get_model_and_tokenizer(log_lines)
        device = str(model.device)
        log_lines.append(f"Running local model execution on device: {device}...")

        # Format the prompt with the model chat template.
        chat_text = _format_chat_prompt(
            tokenizer, [{"role": "user", "content": prompt}]
        )
        model_inputs = tokenizer([chat_text], return_tensors="pt").to(device)

        # Generate a structured CBT reflection.
        log_lines.append("Generating reflection...")
        output_ids = model.generate(
            **model_inputs,
            max_new_tokens=512,
            do_sample=False,
        )

        # Decode only the generated assistant tokens; the batch holds exactly
        # one sequence, so slice off that sequence's prompt prefix.
        prompt_length = len(model_inputs.input_ids[0])
        completion_ids = [output_ids[0][prompt_length:]]
        decoded = tokenizer.batch_decode(completion_ids, skip_special_tokens=True)
        response: str = clean_generated_text(decoded[0])
    except Exception as e:
        # Keep private text local even when inference fails.
        log_lines.append(f"Local model execution failed: {e}.")
        log_lines.append("Inference stopped. No serverless fallback is configured.")
        return "", "\n".join(log_lines)

    log_lines.append("Local model execution completed successfully.")
    return response, "\n".join(log_lines)
def run_chat_inference(
    history: list[dict[str, str]],
    system_prompt: str,
) -> tuple[str, str]:
    """Generates the next chat reply from the locally cached model.

    Args:
        history: Prior chat messages, appended after the system message.
        system_prompt: Content of the leading ``system`` message.

    Returns:
        ``(response, log)`` where ``log`` is the newline-joined progress
        lines. An empty cleaned reply is replaced by a fixed follow-up
        question; on any exception ``response`` is a fixed apology message.
    """
    log_lines: list[str] = []
    try:
        # Reuse the cached local model for follow-up coaching.
        log_lines.append("Initializing local model for chat...")
        model, tokenizer = get_model_and_tokenizer(log_lines)
        device = str(model.device)
        log_lines.append(f"Running local chat model on device: {device}...")

        # Include the system prompt and visible chat history.
        system_message = {"role": "system", "content": system_prompt}
        chat_text = _format_chat_prompt(tokenizer, [system_message] + history)
        model_inputs = tokenizer([chat_text], return_tensors="pt").to(device)

        # Keep chat replies shorter than the initial report.
        log_lines.append("Generating chat response...")
        output_ids = model.generate(
            **model_inputs,
            max_new_tokens=96,
            do_sample=False,
            repetition_penalty=1.18,
            no_repeat_ngram_size=4,
        )

        # Decode only the latest assistant response; the batch holds exactly
        # one sequence, so slice off that sequence's prompt prefix.
        prompt_length = len(model_inputs.input_ids[0])
        completion_ids = [output_ids[0][prompt_length:]]
        decoded = tokenizer.batch_decode(completion_ids, skip_special_tokens=True)
        response: str = clean_generated_text(decoded[0], max_sentences=4)

        # Never show an empty bubble: fall back to a generic coaching prompt.
        if not response:
            response = (
                "That sounds like a painful thought to sit with. "
                "Can we look at one concrete piece of evidence for and against it?"
            )
    except Exception as e:
        # Preserve privacy by failing locally instead of routing to an API.
        log_lines.append(f"Local chat execution failed: {e}.")
        log_lines.append(
            "Chat inference stopped. No serverless fallback is configured."
        )
        error_msg = (
            "I'm sorry, local model execution is currently unavailable. "
            "Please try again after the Space finishes loading the model."
        )
        return error_msg, "\n".join(log_lines)

    log_lines.append("Local chat generation completed successfully.")
    return response, "\n".join(log_lines)