from __future__ import annotations
import os
import re
from typing import Any
import torch
from peft import PeftModel
from transformers import AutoModelForCausalLM, AutoTokenizer
from env.config import ADAPTER_REPO_ID, MODEL_ID
# Module-level cache: the model and tokenizer are loaded on the first call to
# get_model_and_tokenizer() and reused by every later request.
_model: Any = None
_tokenizer: Any = None
# True only when the LoRA adapter was applied to the cached model; used for
# the "Using cached model ..." diagnostic on subsequent requests.
_adapter_applied: bool = False
# Reasoning-capable models sometimes emit hidden "think" blocks. The tag name
# is interpolated into the patterns (rather than written as a literal tag) so
# that markup-stripping tools cannot silently reduce them to match-everything
# regexes, which would erase every generated reply.
_THINK_TAG = "think"
_THINK_BLOCK_PATTERN = re.compile(
    rf"<{_THINK_TAG}>.*?</{_THINK_TAG}>", re.IGNORECASE | re.DOTALL
)
_THINK_START_PATTERN = re.compile(rf"<{_THINK_TAG}>.*", re.IGNORECASE | re.DOTALL)
_THINK_END_PATTERN = re.compile(rf"</{_THINK_TAG}>", re.IGNORECASE)

# Header that starts the structured report; anything before it is preamble.
_EMOTIONS_HEADER = "=== EMOTIONS ==="

# Markers that indicate the model started continuing the chat template.
_CHAT_TEMPLATE_MARKERS = ("<|im_end|>", "<|im_start|>", "\nUser:", "\nAssistant:")

# Whole lines (lower-cased, trailing "." / ":" removed) that are pure
# conversational meta-introductions and carry no content.
_META_INTRO_LINES = frozenset(
    {
        "here goes",
        "here is a reply",
        "here is the reply",
        "sure",
        "here's a reply",
        "here's the reply",
        "here is my reply",
        "here is a cbt reflection",
        "here's a cbt reflection",
        "here is the reflection",
        "cbt reflection",
    }
)

# Speaker labels such as "AI:", "Assistant:", "Coach:", "InnerSpace:", "Response:".
_SPEAKER_PREFIX_PATTERN = re.compile(
    r"^(ai|assistant|coach|innerspace|response):\s*", re.IGNORECASE
)

# A sentence is a run of non-terminators plus the full run of terminators that
# follows it, so "?!" and "..." are kept intact instead of being cut to one
# character.
_SENTENCE_PATTERN = re.compile(r"[^.!?]+[.!?]*")


def clean_generated_text(text: str, max_sentences: int | None = None) -> str:
    """Remove hidden reasoning and trim rambling generated text.

    Args:
        text: Raw decoded model output.
        max_sentences: When ``None`` (report mode) the text is only cleaned.
            When set (chat mode), meta-introduction lines and speaker prefixes
            are stripped as well and the reply is cut to at most this many
            sentences.

    Returns:
        The cleaned text. May be empty when nothing visible remains; callers
        are expected to supply their own fallback in that case.
    """
    # Prefer visible content after a completed hidden-reasoning block. When no
    # closing tag exists, split() returns the text unchanged.
    text = _THINK_END_PATTERN.split(text, maxsplit=1)[-1]
    # Remove any further complete blocks, then a dangling unclosed block.
    text = _THINK_BLOCK_PATTERN.sub("", text)
    text = _THINK_START_PATTERN.sub("", text)

    # Keep only the last structured report, re-attaching its header.
    if _EMOTIONS_HEADER in text:
        text = _EMOTIONS_HEADER + text.rsplit(_EMOTIONS_HEADER, 1)[1]

    # Drop accidental chat-template continuations.
    for marker in _CHAT_TEMPLATE_MARKERS:
        text = text.split(marker, 1)[0]

    # Normalize horizontal whitespace, but preserve newlines.
    text = re.sub(r"[ \t]+", " ", text)
    text = "\n".join(line.strip() for line in text.splitlines())
    text = re.sub(r"\n{3,}", "\n\n", text).strip()

    if max_sentences is None:
        return text

    # Chat replies only: drop meta-introductions and speaker labels.
    kept_lines = []
    for line in text.splitlines():
        if not line:
            continue
        if line.lower().rstrip(".:") in _META_INTRO_LINES:
            continue
        line = _SPEAKER_PREFIX_PATTERN.sub("", line)
        if line:
            kept_lines.append(line)
    text = "\n\n".join(kept_lines)
    if not text:
        return text

    # Keep coach replies compact for the visible chat panel.
    sentences = _SENTENCE_PATTERN.findall(text)
    compact = "".join(sentences[:max_sentences]).strip()
    return compact or text
def _format_chat_prompt(tokenizer: Any, messages: list[dict[str, str]]) -> str:
"""Formats chat input while disabling model-side hidden reasoning when supported."""
try:
return tokenizer.apply_chat_template(
messages,
tokenize=False,
add_generation_prompt=True,
enable_thinking=False,
)
except TypeError:
return tokenizer.apply_chat_template(
messages,
tokenize=False,
add_generation_prompt=True,
)
def get_model_and_tokenizer(log_lines: list[str]) -> tuple[Any, Any]:
    """Return the cached ``(model, tokenizer)`` pair, loading it on first use.

    Progress and diagnostic messages are appended to ``log_lines`` in place.
    On the first call the tokenizer and base model are loaded from
    ``MODEL_ID``, the LoRA adapter from ``ADAPTER_REPO_ID`` is applied if it
    can be loaded (a failure is logged and the base model is used), and the
    model is moved to CUDA when torch reports it available. Later calls only
    log which cached variant is being reused.

    Args:
        log_lines: Mutable list that receives human-readable status lines.

    Returns:
        The module-level model and tokenizer.
    """
    global _adapter_applied, _model, _tokenizer

    if _model is not None:
        # Make cached-request diagnostics explicit without reloading weights.
        variant = "with LoRA adapter" if _adapter_applied else "without adapter"
        log_lines.append(f"Using cached model {variant}.")
        return _model, _tokenizer

    hf_token = os.environ.get("HF_TOKEN")

    # Tokenizer first, so prompt formatting is ready before the weights are.
    log_lines.append(f"Loading tokenizer: {MODEL_ID}")
    _tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, token=hf_token)

    # bfloat16 on CUDA, float32 elsewhere for compatibility.
    log_lines.append(f"Loading model: {MODEL_ID}")
    weight_dtype = torch.float32
    if torch.cuda.is_available():
        weight_dtype = torch.bfloat16
    _model = AutoModelForCausalLM.from_pretrained(
        MODEL_ID,
        dtype=weight_dtype,
        low_cpu_mem_usage=True,
        token=hf_token,
    )

    # The CBT adapter is optional: fall back to the base model if it fails.
    log_lines.append(f"Loading LoRA adapter: {ADAPTER_REPO_ID}")
    try:
        _model = PeftModel.from_pretrained(_model, ADAPTER_REPO_ID, token=hf_token)
    except Exception as adapter_error:
        _adapter_applied = False
        log_lines.append(
            f"Warning: Could not load adapter ({adapter_error}). Using base model."
        )
    else:
        _adapter_applied = True
        log_lines.append("LoRA adapter applied.")

    # Move the loaded model to GPU memory when CUDA is available.
    if torch.cuda.is_available():
        log_lines.append("Moving model to CUDA.")
        _model = _model.to("cuda")

    return _model, _tokenizer
def run_model_inference(prompt: str) -> tuple[str, str]:
    """Generate a single-turn reflection for ``prompt`` on the local model.

    Args:
        prompt: User text sent as the only chat message.

    Returns:
        ``(response, log)`` where ``response`` is the cleaned generated text
        and ``log`` is the newline-joined status trace. If any step raises,
        ``response`` is ``""`` and the log records the failure; nothing is
        sent to a remote fallback.
    """
    trace: list[str] = ["Initializing local model execution..."]
    try:
        # Reuse the cached local model for the journal analysis.
        model, tokenizer = get_model_and_tokenizer(trace)
        target_device = str(model.device)
        trace.append(f"Running local model execution on device: {target_device}...")

        # Format the prompt with the model chat template.
        conversation = [{"role": "user", "content": prompt}]
        prompt_text = _format_chat_prompt(tokenizer, conversation)
        encoded = tokenizer([prompt_text], return_tensors="pt").to(target_device)

        # Generate a structured CBT reflection (greedy decoding).
        trace.append("Generating reflection...")
        sequences = model.generate(**encoded, max_new_tokens=512, do_sample=False)

        # Slice off the prompt tokens so only the assistant's tokens are decoded.
        completions = [
            full_ids[len(prompt_ids) :]
            for prompt_ids, full_ids in zip(encoded.input_ids, sequences)
        ]
        decoded = tokenizer.batch_decode(completions, skip_special_tokens=True)
        reply: str = clean_generated_text(decoded[0])

        trace.append("Local model execution completed successfully.")
        return reply, "\n".join(trace)
    except Exception as error:
        # Keep private text local even when inference fails.
        trace.append(f"Local model execution failed: {error}.")
        trace.append("Inference stopped. No serverless fallback is configured.")
        return "", "\n".join(trace)
def run_chat_inference(
    history: list[dict[str, str]],
    system_prompt: str,
) -> tuple[str, str]:
    """Generate the next coaching reply for a multi-turn chat on the local model.

    Args:
        history: Visible chat messages as ``{"role": ..., "content": ...}``
            dicts; they are appended after the system message.
        system_prompt: Content of the leading system message.

    Returns:
        ``(response, log)``. ``response`` is the cleaned reply limited to four
        sentences, or a canned follow-up question when cleaning leaves nothing.
        If any step raises, ``response`` is a user-facing apology and the log
        records the failure; nothing is sent to a remote fallback.
    """
    trace: list[str] = ["Initializing local model for chat..."]
    try:
        # Reuse the cached local model for follow-up coaching.
        model, tokenizer = get_model_and_tokenizer(trace)
        target_device = str(model.device)
        trace.append(f"Running local chat model on device: {target_device}...")

        # Include the system prompt and visible chat history.
        conversation = [{"role": "system", "content": system_prompt}] + history
        prompt_text = _format_chat_prompt(tokenizer, conversation)
        encoded = tokenizer([prompt_text], return_tensors="pt").to(target_device)

        # Keep chat replies shorter than the initial report and discourage loops.
        trace.append("Generating chat response...")
        sequences = model.generate(
            **encoded,
            max_new_tokens=96,
            do_sample=False,
            repetition_penalty=1.18,
            no_repeat_ngram_size=4,
        )

        # Slice off the prompt tokens so only the latest reply is decoded.
        completions = [
            full_ids[len(prompt_ids) :]
            for prompt_ids, full_ids in zip(encoded.input_ids, sequences)
        ]
        decoded = tokenizer.batch_decode(completions, skip_special_tokens=True)
        reply: str = clean_generated_text(decoded[0], max_sentences=4)
        if not reply:
            # Cleaning removed everything; offer a generic follow-up instead.
            reply = (
                "That sounds like a painful thought to sit with. "
                "Can we look at one concrete piece of evidence for and against it?"
            )

        trace.append("Local chat generation completed successfully.")
        return reply, "\n".join(trace)
    except Exception as error:
        # Preserve privacy by failing locally instead of routing to an API.
        trace.append(f"Local chat execution failed: {error}.")
        trace.append("Chat inference stopped. No serverless fallback is configured.")
        apology = (
            "I'm sorry, local model execution is currently unavailable. "
            "Please try again after the Space finishes loading the model."
        )
        return apology, "\n".join(trace)