# document_redaction_vlm/tools/llm_funcs.py
# Sync: Adjustments on redaction overlay formatting (commit 9e54ad2)
import json
import os
import re
import time
from typing import List, Tuple
from urllib.parse import urlparse
import boto3
import requests
import spaces
from tools.config import (
MAX_SPACES_GPU_RUN_TIME,
PRINT_TRANSFORMERS_USER_PROMPT,
REPORT_LLM_OUTPUTS_TO_GUI,
VLM_DEFAULT_DO_SAMPLE,
)
# Import mock patches if in test mode.
# NOTE(review): the actual patch application below is commented out, so this
# block currently only adjusts sys.path and has no mocking effect.
if os.environ.get("USE_MOCK_LLM") == "1" or os.environ.get("TEST_MODE") == "1":
    try:
        # Try to import and apply mock patches
        import sys

        # Add project root to sys.path so we can import test.mock_llm_calls
        project_root = os.path.dirname(os.path.dirname(__file__))
        if project_root not in sys.path:
            sys.path.insert(0, project_root)
        # try:
        #     from test.mock_llm_calls import apply_mock_patches
        #     apply_mock_patches()
        # except ImportError:
        #     # If mock module not found, continue without mocking
        #     pass
    except Exception:
        # If anything fails, continue without mocking
        pass
try:
from google import genai as ai
from google.genai import types
except ImportError:
print(
"Warning: Google GenAI not found. Google GenAI functionality will not be available."
)
pass
from gradio import Progress
from huggingface_hub import hf_hub_download
try:
from openai import OpenAI
except ImportError:
print("Warning: OpenAI not found. OpenAI functionality will not be available.")
pass
from tqdm import tqdm
# Currently-selected model type; set elsewhere at runtime.
model_type = None  # global variable setup
full_text = (
    ""  # Define dummy source text (full text) just to enable highlight function to load
)
# Global variables for PII detection model and tokenizer.
# These are now used for all LLM model loading (both general and PII-specific).
# They are populated at startup (see the LOAD_TRANSFORMERS_LLM_PII_MODEL_AT_START
# block further down) or lazily on first use.
_pii_model = None
_pii_tokenizer = None
_pii_assistant_model = None
# Import config variables with defaults for missing ones
# This allows llm_funcs.py to work even if some config variables don't exist
from tools.config import (
ASSISTANT_MODEL,
COMPILE_MODE,
COMPILE_TRANSFORMERS,
HF_TOKEN,
INFERENCE_SERVER_DISABLE_THINKING,
INT8_WITH_OFFLOAD_TO_CPU,
LLM_CONTEXT_LENGTH,
LLM_MAX_NEW_TOKENS,
LLM_MIN_P,
LLM_MODEL_DTYPE,
LLM_REPETITION_PENALTY,
LLM_RESET,
LLM_RETRY_ATTEMPTS,
LLM_SEED,
LLM_STOP_STRINGS,
LLM_STREAM,
LLM_TEMPERATURE,
LLM_THREADS,
LLM_TIMEOUT_WAIT,
LLM_TOP_K,
LLM_TOP_P,
LOAD_TRANSFORMERS_LLM_PII_MODEL_AT_START,
LOCAL_TRANSFORMERS_LLM_PII_MODEL_CHOICE,
LOCAL_TRANSFORMERS_LLM_PII_REPO_ID,
MULTIMODAL_PROMPT_FORMAT,
QUANTISE_TRANSFORMERS_LLM_MODELS,
REASONING_SUFFIX,
SELECTED_LOCAL_TRANSFORMERS_VLM_MODEL,
SHOW_TRANSFORMERS_LLM_PII_DETECTION_OPTIONS,
SPECULATIVE_DECODING,
USE_LLAMA_SWAP,
USE_TRANSFORMERS_VLM_MODEL_AS_LLM,
VLM_DISABLE_QWEN3_5_THINKING,
VLM_QWEN3_5_NOTHINK_SUFFIX,
)
def _stringify_openai_message_content(content) -> str:
"""Normalize message.content from OpenAI-compatible APIs (str, null, or list of parts)."""
if content is None:
return ""
if isinstance(content, str):
return content
if isinstance(content, list):
parts = []
for p in content:
if isinstance(p, dict):
t = p.get("text")
if t is None and p.get("type") == "text":
t = p.get("text", "")
if isinstance(t, str):
parts.append(t)
elif isinstance(p, str):
parts.append(p)
return "".join(parts)
return str(content)
def _extract_choice_message_text(choice: dict) -> str:
    """Pull the assistant's text out of a chat-completions ``choice`` dict.

    Prefers ``message.content``; if that is blank, falls back in order to the
    reasoning fields (``reasoning_content``, ``reasoning``) and then to the
    legacy top-level ``text`` field. Non-dict input yields an empty string.
    """
    if not isinstance(choice, dict):
        return ""
    message = choice.get("message") or {}
    primary = _stringify_openai_message_content(message.get("content"))
    if primary and str(primary).strip():
        return primary
    # Reasoning-only responses: surface the reasoning text instead.
    for key in ("reasoning_content", "reasoning"):
        candidate = message.get(key)
        if isinstance(candidate, str) and candidate.strip():
            return candidate
    # Legacy completions shape puts the text directly on the choice.
    fallback = choice.get("text")
    if isinstance(fallback, str) and fallback.strip():
        return fallback
    return primary or ""
def _report_llm_output_to_gui(text: str) -> None:
    """Surface a chunk of streamed LLM output in the Gradio UI via ``gr.Info``.

    Best-effort: does nothing unless REPORT_LLM_OUTPUTS_TO_GUI is enabled and
    *text* is non-blank. Failures from ``gr.Info`` (e.g. no Gradio context in a
    worker process or CLI) are deliberately swallowed.
    """
    if not REPORT_LLM_OUTPUTS_TO_GUI:
        return
    if not (text and str(text).strip()):
        return
    try:
        import gradio as gr

        gr.Info(text, duration=2)
    except Exception:
        # gr.Info may not be available (e.g. in worker process or CLI), ignore
        pass
# Coerce thread count from string config (env-style) to int.
if isinstance(LLM_THREADS, str):
    LLM_THREADS = int(LLM_THREADS)
# Module-level aliases for the generation settings imported from tools.config.
max_tokens = LLM_MAX_NEW_TOKENS
temperature = LLM_TEMPERATURE
top_k = LLM_TOP_K
top_p = LLM_TOP_P
min_p = LLM_MIN_P
repetition_penalty = LLM_REPETITION_PENALTY
LLM_MAX_NEW_TOKENS: int = LLM_MAX_NEW_TOKENS  # re-bind with int annotation; value unchanged
seed: int = LLM_SEED
reset: bool = LLM_RESET
stream: bool = LLM_STREAM
context_length: int = LLM_CONTEXT_LENGTH
speculative_decoding = SPECULATIVE_DECODING
# Default to a single thread when LLM_THREADS is unset/zero.
if not LLM_THREADS:
    threads = 1
else:
    threads = LLM_THREADS
timeout_wait = LLM_TIMEOUT_WAIT
number_of_api_retry_attempts = LLM_RETRY_ATTEMPTS


class LocalLLMContextConfig:
    """Holds context length and GPU layer count for local transformers model loading."""

    def __init__(self, n_ctx: int = context_length, n_gpu_layers: int = -1):
        # n_gpu_layers: -1 for default GPU offload; 0 means CPU-only (see load_model).
        self.n_ctx = n_ctx
        self.n_gpu_layers = n_gpu_layers

    def update_gpu(self, new_value: int) -> None:
        # Replace the GPU layer count in place.
        self.n_gpu_layers = new_value

    def update_context(self, new_value: int) -> None:
        # Replace the maximum context length in place.
        self.n_ctx = new_value


# GPU and CPU context configs for load_model (CPU uses 0 GPU layers).
local_gpu_context = LocalLLMContextConfig(n_ctx=context_length, n_gpu_layers=-1)
local_cpu_context = LocalLLMContextConfig(n_ctx=context_length, n_gpu_layers=0)
class LocalLLMGenerationConfig:
    """Sampling/generation settings for local transformers models.

    Defaults come from the module-level values imported from tools.config.
    Instances are consumed by ``call_transformers_model`` when it builds the
    ``model.generate`` keyword arguments.
    """

    def __init__(
        self,
        temperature=temperature,
        top_k=top_k,
        min_p=min_p,
        top_p=top_p,
        repeat_penalty=repetition_penalty,
        seed=seed,
        stream=stream,
        max_tokens=LLM_MAX_NEW_TOKENS,
        reset=reset,
    ):
        self.temperature = temperature
        self.top_k = top_k
        self.top_p = top_p
        # Bug fix: min_p was previously accepted but never stored, so any
        # caller-supplied value was silently discarded.
        self.min_p = min_p
        self.repeat_penalty = repeat_penalty
        self.seed = seed
        self.max_tokens = max_tokens
        self.stream = stream
        self.reset = reset

    def update_temp(self, new_value):
        """Override the sampling temperature in place."""
        self.temperature = new_value
# ResponseObject class for AWS Bedrock calls
class ResponseObject:
    """Simple container for AWS Bedrock call results.

    Holds the response text and the usage metadata returned by the API so
    downstream code can access them as attributes.
    """

    def __init__(self, text, usage_metadata):
        self.usage_metadata = usage_metadata
        self.text = text
###
# LOCAL MODEL FUNCTIONS
###
def get_model_path(
    repo_id=LOCAL_TRANSFORMERS_LLM_PII_REPO_ID,
    model_filename="",
    model_dir="",
    hf_token=HF_TOKEN,
):
    """Return a local path to the model file, downloading it if necessary.

    Checks ``model_dir/model_filename`` first; when the file is not present
    locally, downloads ``model_filename`` from the Hugging Face Hub repository
    ``repo_id`` (authenticated with ``hf_token`` when one is provided).

    Args:
        repo_id: Hugging Face repository ID to download from.
        model_filename: Filename of the model artifact within the repository.
        model_dir: Local directory checked for an existing copy.
        hf_token: Optional Hugging Face access token.

    Returns:
        str: Path to the local (existing or freshly downloaded) model file.

    Raises:
        Warning: wraps any failure during lookup/download. Kept as ``Warning``
            for backward compatibility with existing callers; the original
            exception is chained for debugging.
    """
    # Construct the expected local path
    local_path = os.path.join(model_dir, model_filename)
    print("local path for model load:", local_path)
    try:
        if os.path.exists(local_path):
            print(f"Model already exists at: {local_path}")
            return local_path
        # Not cached locally: fetch from the Hub. hf_hub_download treats
        # token=None the same as omitting it, so both branches collapse
        # into a single call below.
        if hf_token:
            print("Downloading model from Hugging Face Hub with HF token")
        else:
            print(
                "No HF token found, downloading model from Hugging Face Hub without token"
            )
        return hf_hub_download(
            repo_id=repo_id, filename=model_filename, token=hf_token or None
        )
    except Exception as e:
        print("Error loading model:", e)
        raise Warning("Error loading model:", e) from e
def _normalize_huggingface_repo_id(repo_id: str) -> str:
"""
If repo_id is an http(s) URL for huggingface.co, return the org/model path segment.
Uses parsed host validation (not substring checks) to satisfy CodeQL py/incomplete-url-substring-sanitization.
"""
s = repo_id.strip()
lower = s.lower()
if not (lower.startswith("https://") or lower.startswith("http://")):
return repo_id
parsed = urlparse(s)
if parsed.scheme.lower() not in ("http", "https"):
return repo_id
host = (parsed.hostname or "").lower()
if host not in ("huggingface.co", "www.huggingface.co"):
return repo_id
path = parsed.path.strip("/")
if not path:
return repo_id
return path
def load_model(
    local_model_type: str = None,
    gpu_layers: int = -1,
    max_context_length: int = context_length,
    gpu_context: LocalLLMContextConfig = local_gpu_context,
    cpu_context: LocalLLMContextConfig = local_cpu_context,
    torch_device: str = "cpu",
    repo_id=LOCAL_TRANSFORMERS_LLM_PII_REPO_ID,
    model_filename="",
    model_dir="",
    compile_mode=COMPILE_MODE,
    model_dtype=LLM_MODEL_DTYPE,
    hf_token=HF_TOKEN,
    speculative_decoding=speculative_decoding,
    model=None,
    tokenizer=None,
    assistant_model=None,
):
    """
    Load a model from Hugging Face Hub via the transformers package.

    Args:
        local_model_type (str): The type of local model to load.
        gpu_layers (int): The number of GPU layers to offload to the GPU (-1 for default).
        max_context_length (int): The maximum context length for the model.
        gpu_context (LocalLLMContextConfig): Context config for GPU (n_ctx, n_gpu_layers).
        cpu_context (LocalLLMContextConfig): Context config for CPU.
        torch_device (str): The device to load the model on ("cuda" or "cpu").
            Note: this is overwritten below based on torch.cuda.is_available().
        repo_id (str): The Hugging Face repository ID where the model is located.
        model_filename (str): The specific filename of the model to download from the repository.
        model_dir (str): The local directory where the model will be stored or downloaded.
        compile_mode (str): The compilation mode to use for the model.
        model_dtype (str): The data type to use for the model.
        hf_token (str): The Hugging Face token to use for the model.
        speculative_decoding (bool): Whether to use speculative decoding.
        model (transformers model): Optional pre-loaded model (skips loading if provided).
        tokenizer (transformers tokenizer): Optional pre-loaded tokenizer.
        assistant_model (transformers model): Optional assistant model for speculative decoding.

    Returns:
        tuple: (model, tokenizer, assistant_model). Any element may be None
        if the corresponding load failed or was skipped.
    """
    # If model is provided, validate that tokenizer is also provided and compatible
    if model:
        if tokenizer is None:
            print(
                "Warning: Model provided but tokenizer is None. Attempting to load matching tokenizer..."
            )
            # Try to determine model_id from model config
            try:
                if hasattr(model, "config") and hasattr(model.config, "_name_or_path"):
                    model_id = model.config._name_or_path
                    from transformers import AutoTokenizer

                    tokenizer = AutoTokenizer.from_pretrained(model_id, token=hf_token)
                    if not tokenizer.pad_token:
                        tokenizer.pad_token = tokenizer.eos_token
                    print(f"Loaded matching tokenizer from {model_id}")
                else:
                    print(
                        "Warning: Could not determine model source to load matching tokenizer"
                    )
            except Exception as e:
                print(f"Warning: Failed to load matching tokenizer: {e}")
        # Caller supplied a pre-loaded model: return immediately without loading.
        return model, tokenizer, assistant_model
    # Use LOCAL_TRANSFORMERS_LLM_PII_MODEL_CHOICE if local_model_type is not provided
    if local_model_type is None:
        local_model_type = LOCAL_TRANSFORMERS_LLM_PII_MODEL_CHOICE
    # Accept full huggingface.co URLs by reducing them to org/model form.
    if isinstance(repo_id, str):
        repo_id = _normalize_huggingface_repo_id(repo_id)
    print("Loading model:", local_model_type)
    # Verify the device and cuda settings
    # Check if CUDA is enabled
    import torch

    torch.cuda.empty_cache()
    print("Is CUDA enabled? ", torch.cuda.is_available())
    print("Is a CUDA device available on this computer?", torch.backends.cudnn.enabled)
    if torch.cuda.is_available():
        torch_device = "cuda"
        print("CUDA version:", torch.version.cuda)
        # try:
        #     os.system("nvidia-smi")
        # except Exception as e:
        #     print("Could not print nvidia-smi settings due to:", e)
    else:
        torch_device = "cpu"
        gpu_layers = 0
    print("Running on device:", torch_device)
    print("GPU layers assigned to cuda:", gpu_layers)
    if not LLM_THREADS:
        threads = torch.get_num_threads()
    else:
        threads = LLM_THREADS
    print("CPU threads:", threads)
    # GPU mode
    if torch_device == "cuda":
        torch.cuda.empty_cache()
        gpu_context.update_gpu(gpu_layers)
        gpu_context.update_context(max_context_length)
        from transformers import (
            AutoModelForCausalLM,
            AutoTokenizer,
            BitsAndBytesConfig,
        )

        print("Loading model from transformers")
        # Use the official model ID for Gemma 3 4B
        model_id = repo_id
        # 1. Set Data Type (dtype)
        # For H200/Hopper: 'bfloat16'
        # For RTX 3060/Ampere: 'float16'
        dtype_str = model_dtype  # os.environ.get("LLM_MODEL_DTYPE", "bfloat16").lower()
        if dtype_str == "bfloat16":
            torch_dtype = torch.bfloat16
        elif dtype_str == "float16":
            torch_dtype = torch.float16
        elif dtype_str == "auto":
            torch_dtype = "auto"
        else:
            torch_dtype = torch.float32  # A safe fallback
        # 2. Set Compilation Mode
        # 'max-autotune' is great for both but can be slow initially.
        # 'reduce-overhead' is a faster alternative for compiling.
        print("--- System Configuration ---")
        print(f"Using model id: {model_id}")
        print(f"Using dtype: {torch_dtype}")
        print(f"Using compile mode: {compile_mode}")
        print(f"Using quantization: {QUANTISE_TRANSFORMERS_LLM_MODELS}")
        print("--------------------------\n")
        # --- Load Tokenizer and Model Atomically ---
        # Ensure both model and tokenizer are loaded from the same source
        # If either fails, both should fail together to prevent mismatched pairs
        try:
            # Setup quantization config if enabled
            quantization_config = None
            if QUANTISE_TRANSFORMERS_LLM_MODELS:
                if not torch.cuda.is_available():
                    print(
                        "Warning: Quantisation requires CUDA, but CUDA is not available."
                    )
                    print("Falling back to loading models without quantisation")
                    quantization_config = None
                else:
                    if INT8_WITH_OFFLOAD_TO_CPU:
                        # This will be very slow. Requires at least 4GB of VRAM and 32GB of RAM
                        print(
                            "Using bitsandbytes for quantisation to 8 bits, with offloading to CPU"
                        )
                        max_memory = {0: "4GB", "cpu": "32GB"}
                        quantization_config = BitsAndBytesConfig(
                            load_in_8bit=True,
                            max_memory=max_memory,
                            llm_int8_enable_fp32_cpu_offload=True,  # Note: if bitsandbytes has to offload to CPU, inference will be slow
                        )
                    else:
                        # For Gemma 4B, requires at least 6GB of VRAM
                        print("Using bitsandbytes for quantisation to 4 bits")
                        quantization_config = BitsAndBytesConfig(
                            load_in_4bit=True,
                            bnb_4bit_quant_type="nf4",  # Use the modern NF4 quantisation for better performance
                            bnb_4bit_compute_dtype=torch_dtype,
                            # bnb_4bit_use_double_quant=True, # Optional: uses a second quantisation step to save even more memory
                        )
            # Prepare load kwargs
            # Match VLM behavior: always use device_map="auto" for better device handling
            load_kwargs = {
                # "max_seq_length": max_context_length,
                "token": hf_token,
                "device_map": "auto",  # Always use device_map="auto" like VLM
            }
            if quantization_config is not None:
                load_kwargs["quantization_config"] = quantization_config
                print("Loading model with bitsandbytes quantisation")
            else:
                # Use "auto" dtype like VLM for better compatibility
                load_kwargs["dtype"] = "auto" if model_dtype == "auto" else torch_dtype
                print("Loading model without quantisation")
            # Load tokenizer FIRST to validate the model_id is accessible
            # This ensures we catch tokenizer errors before loading the (larger) model
            print(f"Loading tokenizer from {model_id}...")
            tokenizer = AutoTokenizer.from_pretrained(
                model_id,
                token=hf_token,
                trust_remote_code=True,
            )
            if not tokenizer.pad_token:
                tokenizer.pad_token = tokenizer.eos_token
            print("Tokenizer loaded successfully")
            # Load model from the SAME model_id to ensure compatibility
            # Model class is chosen from the human-readable local_model_type name.
            if "qwen" in local_model_type.lower() and "3.5" in local_model_type.lower():
                print(f"Loading Qwen 3.5 model from {model_id}...")
                from transformers import (
                    Qwen3_5ForCausalLM,
                )

                model = Qwen3_5ForCausalLM.from_pretrained(
                    model_id,
                    trust_remote_code=True,
                    **load_kwargs,
                )
            elif (
                "qwen" in local_model_type.lower() and "3 " in local_model_type.lower()
            ):
                print(f"Loading Qwen 3 model from {model_id}...")
                from transformers import Qwen3VLForConditionalGeneration

                model = Qwen3VLForConditionalGeneration.from_pretrained(
                    model_id,
                    trust_remote_code=True,
                    **load_kwargs,
                )
            else:
                print(f"Loading model from {model_id}...")
                model = AutoModelForCausalLM.from_pretrained(
                    model_id,
                    trust_remote_code=True,
                    **load_kwargs,
                )
            # Set model to evaluation mode (standard transformers approach)
            # Note: With device_map="auto", don't manually move model - let it handle device placement
            model.eval()
            print("Model loaded successfully")
            # Validate that model and tokenizer are from the same source
            if hasattr(model, "config") and hasattr(model.config, "_name_or_path"):
                model_source = model.config._name_or_path
                if hasattr(tokenizer, "name_or_path"):
                    tokenizer_source = tokenizer.name_or_path
                    if model_source != tokenizer_source and model_id not in [
                        model_source,
                        tokenizer_source,
                    ]:
                        print(
                            f"Warning: Model source ({model_source}) and tokenizer source ({tokenizer_source}) may differ. Using model_id: {model_id}"
                        )
        except Exception as e:
            # If loading fails, ensure both model and tokenizer are None to prevent partial state
            print(f"Error loading model and tokenizer: {e}")
            model = None
            tokenizer = None
            raise RuntimeError(
                f"Failed to load model and tokenizer from {model_id}: {e}"
            ) from e
        # Compile the Model with the selected mode 🚀
        if COMPILE_TRANSFORMERS:
            try:
                model = torch.compile(model, mode=compile_mode, fullgraph=False)
            except Exception as e:
                print(f"Could not compile model: {e}. Running in eager mode.")
        print(
            "Loading with",
            gpu_context.n_gpu_layers,
            "model layers sent to GPU and a maximum context length of",
            gpu_context.n_ctx,
        )
    # CPU mode
    # NOTE(review): this branch only loads a tokenizer; `model` remains None
    # here — callers receive (None, tokenizer, None) on CPU. Confirm intended.
    else:
        try:
            from transformers import AutoTokenizer

            model_id = repo_id
            tokenizer = AutoTokenizer.from_pretrained(
                model_id,
                token=hf_token,
                trust_remote_code=True,
            )
            if not tokenizer.pad_token:
                tokenizer.pad_token = tokenizer.eos_token
            print(f"Loaded tokenizer from {model_id} for compatibility")
        except Exception as e:
            print(f"Warning: Could not load tokenizer: {e}")
            tokenizer = None
        print(
            "Loading with",
            cpu_context.n_gpu_layers,
            "model layers sent to GPU and a maximum context length of",
            cpu_context.n_ctx,
        )
    print("Finished loading model:", local_model_type)
    print("GPU layers assigned to cuda:", gpu_layers)
    # Load assistant model for speculative decoding if enabled
    # Note: Assistant model typically shares the same tokenizer as the main model
    # for speculative decoding, so we don't load a separate tokenizer for it
    if speculative_decoding and torch_device == "cuda":
        print("Loading assistant model for speculative decoding:", ASSISTANT_MODEL)
        try:
            from transformers import (
                AutoModelForCausalLM,
                BitsAndBytesConfig,
            )

            # Setup quantization config for assistant model (same as main model)
            assistant_quantization_config = None
            if QUANTISE_TRANSFORMERS_LLM_MODELS and torch.cuda.is_available():
                if INT8_WITH_OFFLOAD_TO_CPU:
                    max_memory = {0: "4GB", "cpu": "32GB"}
                    assistant_quantization_config = BitsAndBytesConfig(
                        load_in_8bit=True,
                        max_memory=max_memory,
                        llm_int8_enable_fp32_cpu_offload=True,
                    )
                else:
                    assistant_quantization_config = BitsAndBytesConfig(
                        load_in_4bit=True,
                        bnb_4bit_quant_type="nf4",
                        bnb_4bit_compute_dtype=torch_dtype,
                        bnb_4bit_use_double_quant=True,
                    )
            # Prepare load kwargs for assistant model
            assistant_load_kwargs = {
                "token": hf_token,
            }
            if assistant_quantization_config is not None:
                assistant_load_kwargs["quantization_config"] = (
                    assistant_quantization_config
                )
                assistant_load_kwargs["device_map"] = "auto"
                print("Loading assistant model with bitsandbytes quantisation")
            else:
                assistant_load_kwargs["dtype"] = torch_dtype
                print("Loading assistant model without quantisation")
            # Load the assistant model from ASSISTANT_MODEL
            # Note: Assistant model should be compatible with the main model's tokenizer
            # for speculative decoding to work correctly
            print(f"Loading assistant model from {ASSISTANT_MODEL}...")
            assistant_model = AutoModelForCausalLM.from_pretrained(
                ASSISTANT_MODEL, **assistant_load_kwargs
            )
            # For non-quantized assistant models, explicitly move to device (matching VLM behavior)
            if assistant_quantization_config is None:
                device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
                assistant_model = assistant_model.to(device)
            # Validate that assistant model can work with the main tokenizer
            # For speculative decoding, both models should use compatible tokenizers
            if hasattr(assistant_model, "config") and hasattr(
                assistant_model.config, "_name_or_path"
            ):
                assistant_source = assistant_model.config._name_or_path
                if hasattr(tokenizer, "name_or_path"):
                    tokenizer_source = tokenizer.name_or_path
                    if assistant_source != tokenizer_source:
                        print(
                            f"Warning: Assistant model ({assistant_source}) and tokenizer ({tokenizer_source}) are from different sources."
                        )
                        print(
                            "This may cause issues with speculative decoding. Ensure they are compatible."
                        )
            # Compile the assistant model if compilation is enabled
            if COMPILE_TRANSFORMERS:
                try:
                    assistant_model = torch.compile(
                        assistant_model, mode=compile_mode, fullgraph=False
                    )
                except Exception as e:
                    print(
                        f"Could not compile assistant model: {e}. Running in eager mode."
                    )
            print("Successfully loaded assistant model for speculative decoding")
            print("Note: Assistant model uses the same tokenizer as the main model")
        except Exception as e:
            # Assistant model is optional; fall back to plain decoding on failure.
            print(f"Error loading assistant model: {e}")
            assistant_model = None
    else:
        assistant_model = None
    return model, tokenizer, assistant_model
# Initialize PII model at startup if configured (even if SHOW_TRANSFORMERS_LLM_PII_DETECTION_OPTIONS is False)
# This allows PII model to be loaded independently for PII detection tasks.
# Failures here are non-fatal: the model is then loaded on demand instead.
if (
    LOAD_TRANSFORMERS_LLM_PII_MODEL_AT_START
    and SHOW_TRANSFORMERS_LLM_PII_DETECTION_OPTIONS
):
    try:
        print("Loading local PII model:", LOCAL_TRANSFORMERS_LLM_PII_MODEL_CHOICE)
        # _pii_model/_pii_tokenizer/_pii_assistant_model are None at this point,
        # so load_model performs a fresh load and populates the module globals.
        _pii_model, _pii_tokenizer, _pii_assistant_model = load_model(
            local_model_type=LOCAL_TRANSFORMERS_LLM_PII_MODEL_CHOICE,
            max_context_length=context_length,
            gpu_context=local_gpu_context,
            cpu_context=local_cpu_context,
            repo_id=LOCAL_TRANSFORMERS_LLM_PII_REPO_ID,
            model_filename="",
            model_dir="",
            compile_mode=COMPILE_MODE,
            model_dtype=LLM_MODEL_DTYPE,
            hf_token=HF_TOKEN,
            model=_pii_model,
            tokenizer=_pii_tokenizer,
            assistant_model=_pii_assistant_model,
        )
    except Exception as e:
        print(f"Warning: Could not load PII model at startup: {e}")
        print("PII model will be loaded on-demand when needed.")
@spaces.GPU(duration=MAX_SPACES_GPU_RUN_TIME)
def call_transformers_model(
    prompt: str,
    system_prompt: str,
    gen_config: LocalLLMGenerationConfig,
    model=_pii_model,
    tokenizer=_pii_tokenizer,
    assistant_model=_pii_assistant_model,
    speculative_decoding=speculative_decoding,
    use_vlm_safe_generation=VLM_DEFAULT_DO_SAMPLE,
):
    """
    This function sends a request to a transformers model with the given prompt, system prompt, and generation configuration.

    When use_vlm_safe_generation is True (e.g. VLM model used for LLM tasks), uses greedy decoding to avoid
    sampling-related CUDA errors (e.g. invalid probability tensor in multinomial).

    Returns:
        tuple: (assistant_reply, num_input_tokens, num_generated_tokens).
    """
    import torch
    from transformers import TextStreamer

    # Custom streamer that reports streamed output to gr.Info when REPORT_LLM_OUTPUTS_TO_GUI is True
    class _LLMGUIStreamer(TextStreamer):
        def __init__(self, tokenizer, skip_prompt=True):
            super().__init__(tokenizer, skip_prompt=skip_prompt)
            # Buffers partial lines until a newline (or stream end) arrives.
            self._line_buffer = ""

        def on_finalized_text(self, text, stream_end=False):
            # Keep the default console streaming behaviour.
            super().on_finalized_text(text, stream_end)
            if not REPORT_LLM_OUTPUTS_TO_GUI:
                return
            self._line_buffer += text
            if "\n" in text or stream_end:
                parts = self._line_buffer.split("\n")
                # Report each completed line; keep the trailing partial line buffered.
                for line in parts[:-1]:
                    if line.strip():
                        _report_llm_output_to_gui(line)
                self._line_buffer = parts[-1] if parts else ""
                if stream_end and self._line_buffer.strip():
                    _report_llm_output_to_gui(self._line_buffer)

    # Load model and tokenizer together to ensure they're from the same source
    # This prevents mismatches that could occur if they're loaded separately
    if model is None or tokenizer is None:
        print("Model not found. Loading model and tokenizer...")
        # Use get_model_and_tokenizer() to ensure both are loaded atomically
        # This is safer than calling get_pii_model() and get_pii_tokenizer() separately
        loaded_model, loaded_tokenizer, assistant_model = load_model()
        if model is None:
            model = loaded_model
        if tokenizer is None:
            tokenizer = loaded_tokenizer
        # if assistant_model is None and speculative_decoding:
        #     assistant_model = # get_assistant_model()
    if model is None or tokenizer is None:
        raise ValueError(
            "No model or tokenizer available. Either pass them as parameters or ensure LOAD_TRANSFORMERS_LLM_PII_MODEL_AT_START is True."
        )
    # Apply reasoning suffix to prompt if configured
    if REASONING_SUFFIX and REASONING_SUFFIX.strip():
        prompt = f"{prompt} {REASONING_SUFFIX}".strip()
    # When using VLM as LLM with Qwen3.5 thinking disabled, we append <think></think> after the generation
    # prompt so the model continues with the answer (avoids continue_final_message which can fail
    # when the chat template does not include the final assistant message in the rendered string).
    add_nothink_assistant_turn = (
        VLM_DISABLE_QWEN3_5_THINKING
        and "Qwen 3.5" in LOCAL_TRANSFORMERS_LLM_PII_MODEL_CHOICE
    ) or (
        VLM_DISABLE_QWEN3_5_THINKING
        and USE_TRANSFORMERS_VLM_MODEL_AS_LLM
        and (
            "Qwen 3.5" in SELECTED_LOCAL_TRANSFORMERS_VLM_MODEL
            or "Qwen3.5" in SELECTED_LOCAL_TRANSFORMERS_VLM_MODEL
        )
    )
    # 1. Define the conversation as a list of dictionaries
    # Note: The multimodal format [{"type": "text", "text": text}] is only needed for actual multimodal models
    # with images/videos. For text-only content, even multimodal models expect plain strings.
    # Check if system_prompt is meaningful (not empty/None)
    has_system_prompt = system_prompt and str(system_prompt).strip()
    # Always use string format for text-only content, regardless of MULTIMODAL_PROMPT_FORMAT setting
    # MULTIMODAL_PROMPT_FORMAT should only be used when you actually have multimodal inputs (images, etc.)
    if MULTIMODAL_PROMPT_FORMAT:
        conversation = []
        if has_system_prompt:
            conversation.append(
                {
                    "role": "system",
                    "content": [{"type": "text", "text": str(system_prompt)}],
                }
            )
        conversation.append(
            {"role": "user", "content": [{"type": "text", "text": str(prompt)}]}
        )
    else:
        conversation = []
        if has_system_prompt:
            conversation.append({"role": "system", "content": str(system_prompt)})
        conversation.append({"role": "user", "content": str(prompt)})
    if PRINT_TRANSFORMERS_USER_PROMPT:
        print("System prompt:", system_prompt)
        print("User prompt:", prompt)
    device = "cuda" if torch.cuda.is_available() else "cpu"
    # NOTE(review): explicit .to(device) here may conflict with models loaded
    # via device_map="auto" in load_model — confirm this is intended.
    model.to(device)
    if assistant_model is not None:
        assistant_model = assistant_model.to(device)
    if PRINT_TRANSFORMERS_USER_PROMPT:
        print("Model device:", device)
        print("Model device type:", type(device))
    try:
        # Try applying chat template with system prompt (if present)
        # Create inputs dict like VLM does - this allows model to handle device placement automatically
        # From transformers v5, apply_chat_template returns BatchEncoding; extract input_ids tensor
        _encoded = tokenizer.apply_chat_template(
            conversation,
            add_generation_prompt=True,
            tokenize=True,
            return_tensors="pt",
        )
        input_ids = (
            _encoded["input_ids"].to(device)
            if hasattr(_encoded, "keys")
            else _encoded.to(device)
        )
        if PRINT_TRANSFORMERS_USER_PROMPT:
            print("Input IDs:", input_ids)
            print("Rendered prompt:")
            rendered = tokenizer.apply_chat_template(
                conversation,
                add_generation_prompt=True,
                tokenize=False,
            )
            print(rendered)
            print("-" * 50)
    except (TypeError, KeyError, IndexError, ValueError) as e:
        # If chat template fails, try without system prompt (some models don't support it)
        if has_system_prompt:
            print(
                f"Chat template failed with system prompt ({e}), trying without system prompt..."
            )
            # Try again with only user prompt
            user_only_conversation = [{"role": "user", "content": str(prompt)}]
            try:
                _encoded = tokenizer.apply_chat_template(
                    user_only_conversation,
                    add_generation_prompt=True,
                    tokenize=True,
                    return_tensors="pt",
                )
                input_ids = (
                    _encoded["input_ids"].to(device)
                    if hasattr(_encoded, "keys")
                    else _encoded.to(device)
                )
                if PRINT_TRANSFORMERS_USER_PROMPT:
                    print("Input IDs:", input_ids)
                    print("Rendered prompt (without system):")
                    rendered = tokenizer.apply_chat_template(
                        user_only_conversation,
                        add_generation_prompt=True,
                        tokenize=False,
                    )
                    print(rendered)
                    print("-" * 50)
            except Exception as e2:
                print(
                    f"Chat template failed without system prompt ({e2}), using manual tokenization"
                )
                # Combine system and user prompts manually as fallback
                full_prompt = (
                    f"{system_prompt}\n\n{prompt}" if has_system_prompt else prompt
                )
                # Tokenize manually with special tokens (tokenizer() returns BatchEncoding; extract tensor)
                encoded = tokenizer(
                    full_prompt, return_tensors="pt", add_special_tokens=True
                )
                input_ids = encoded["input_ids"].to(device)
        else:
            # No system prompt, but chat template still failed - use manual tokenization
            print(f"Chat template failed ({e}), using manual tokenization")
            full_prompt = str(prompt)
            encoded = tokenizer(
                full_prompt, return_tensors="pt", add_special_tokens=True
            )
            input_ids = encoded["input_ids"].to(device)
    except Exception as e:
        print("Error applying chat template:", e)
        import traceback

        traceback.print_exc()
        raise
    attention_mask = torch.ones_like(input_ids).to(device)
    # When disabling Qwen3.5 thinking, append suffix to prompt so model continues with the answer (same as run_vlm).
    if add_nothink_assistant_turn:
        nothink_tokens = tokenizer.encode(
            VLM_QWEN3_5_NOTHINK_SUFFIX, add_special_tokens=False, return_tensors="pt"
        )
        if nothink_tokens.dim() == 1:
            nothink_tokens = nothink_tokens.unsqueeze(0)
        nothink_tokens = nothink_tokens.to(device)
        input_ids = torch.cat([input_ids, nothink_tokens], dim=1)
        # Extend the attention mask to cover the appended suffix tokens.
        attention_mask = torch.cat(
            [
                attention_mask,
                torch.ones(
                    (attention_mask.shape[0], nothink_tokens.shape[1]),
                    device=device,
                    dtype=attention_mask.dtype,
                ),
            ],
            dim=1,
        )
    # Map generation config to transformers parameters.
    # When use_vlm_safe_generation (VLM model used for LLM tasks), use greedy decoding to avoid
    # "probability tensor contains inf/nan or element < 0" errors in torch.multinomial on some setups.
    if use_vlm_safe_generation:
        generation_kwargs = {
            "max_new_tokens": gen_config.max_tokens,
            "do_sample": False,
            "attention_mask": attention_mask,
        }
    else:
        generation_kwargs = {
            "max_new_tokens": gen_config.max_tokens,
            "temperature": gen_config.temperature,
            "top_p": gen_config.top_p,
            "top_k": gen_config.top_k,
            "do_sample": True,
            "attention_mask": attention_mask,
        }
    if gen_config.stream:
        streamer = (
            _LLMGUIStreamer(tokenizer, skip_prompt=True)
            if REPORT_LLM_OUTPUTS_TO_GUI
            else TextStreamer(tokenizer, skip_prompt=True)
        )
    else:
        streamer = None
    # Remove parameters that don't exist in transformers (repetition_penalty is valid for both sampling and greedy)
    if hasattr(gen_config, "repeat_penalty") and gen_config.repeat_penalty is not None:
        generation_kwargs["repetition_penalty"] = gen_config.repeat_penalty
    if PRINT_TRANSFORMERS_USER_PROMPT:
        print("Generation kwargs:", generation_kwargs)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
        model.config.pad_token_id = tokenizer.pad_token_id
    # --- Timed Inference Test ---
    print("\nStarting model inference...")
    start_time = time.time()
    # Use speculative decoding if assistant model is available
    try:
        if speculative_decoding and assistant_model is not None:
            if PRINT_TRANSFORMERS_USER_PROMPT:
                print("Using speculative decoding with assistant model")
            outputs = model.generate(
                input_ids,
                assistant_model=assistant_model,
                **generation_kwargs,
                streamer=streamer,
            )
        else:
            if PRINT_TRANSFORMERS_USER_PROMPT:
                print("Generating without speculative decoding")
            outputs = model.generate(input_ids, **generation_kwargs, streamer=streamer)
    except Exception as e:
        error_msg = str(e)
        # Check if this is a CUDA compilation error
        if (
            "sm_120" in error_msg
            or "LLVM ERROR" in error_msg
            or "Cannot select" in error_msg
        ):
            print("\n" + "=" * 80)
            print("CUDA COMPILATION ERROR DETECTED")
            print("=" * 80)
            print(
                "\nThe error is caused by torch.compile() trying to compile CUDA kernels"
            )
            print(
                "with incompatible settings. This is a known issue with certain CUDA/PyTorch"
            )
            print("combinations.\n")
            print(
                "SOLUTION: Disable model compilation by setting COMPILE_TRANSFORMERS=False"
            )
            print("in your config file (config/app_config.env).")
            print(
                "\nThe model will still work without compilation, just slightly slower."
            )
            print("=" * 80 + "\n")
            raise RuntimeError(
                "CUDA compilation error detected. Please set COMPILE_TRANSFORMERS=False "
                "in your config file to disable model compilation and avoid this error."
            ) from e
        else:
            # Re-raise other errors as-is
            raise
    end_time = time.time()
    # --- Decode and Display Results ---
    # Extract only the newly generated tokens (exclude input tokens)
    input_length = input_ids.shape[-1]
    # Handle different output formats from model.generate()
    # model.generate() returns a tensor with shape [batch_size, sequence_length]
    # that includes both input and generated tokens
    if isinstance(outputs, torch.Tensor):
        # If outputs is a tensor, extract the new tokens
        if outputs.dim() == 2:
            # Shape: [batch_size, sequence_length]
            new_tokens = outputs[0, input_length:].clone()
        elif outputs.dim() == 1:
            # Shape: [sequence_length] (single sequence)
            new_tokens = outputs[input_length:].clone()
        else:
            raise ValueError(f"Unexpected output tensor shape: {outputs.shape}")
    else:
        # If outputs is a sequence or other format
        if hasattr(outputs, "__getitem__"):
            new_tokens = (
                outputs[0][input_length:]
                if len(outputs) > 0
                else outputs[input_length:]
            )
        else:
            raise ValueError(f"Unexpected output type: {type(outputs)}")
    # Ensure new_tokens is a tensor and on CPU for decoding
    if isinstance(new_tokens, torch.Tensor):
        new_tokens = new_tokens.cpu().clone()
        # Convert to list for decoding (some tokenizers prefer lists)
        new_tokens_list = new_tokens.tolist()
    else:
        new_tokens_list = (
            list(new_tokens) if hasattr(new_tokens, "__iter__") else [new_tokens]
        )
    if PRINT_TRANSFORMERS_USER_PROMPT:
        print(f"Input length: {input_length}")
        print(f"Output shape: {outputs.shape if hasattr(outputs, 'shape') else 'N/A'}")
        print(f"New tokens count: {len(new_tokens_list)}")
        print(f"First 20 new token IDs: {new_tokens_list[:20]}")
    # Decode the tokens
    # Use the token list for decoding (more reliable than tensor)
    try:
        assistant_reply = tokenizer.decode(
            new_tokens_list, skip_special_tokens=True, clean_up_tokenization_spaces=True
        )
    except Exception as e:
        print(f"Warning: Error decoding tokens: {e}")
        print(f"New tokens count: {len(new_tokens_list)}")
        print(f"New tokens (first 20): {new_tokens_list[:20]}")
        # Try alternative decoding methods
        try:
            # Try with tensor directly
            if isinstance(new_tokens, torch.Tensor):
                assistant_reply = tokenizer.decode(
                    new_tokens,
                    skip_special_tokens=True,
                    clean_up_tokenization_spaces=True,
                )
            else:
                raise e
        except Exception as e2:
            print(f"Error with tensor decoding: {e2}")
            # Last resort: try to decode each token individually to see which ones fail
            try:
                decoded_parts = []
                failed_tokens = []
                for i, token_id in enumerate(
                    new_tokens_list[:200]
                ):  # Limit to first 200 to avoid issues
                    try:
                        decoded = tokenizer.decode([token_id], skip_special_tokens=True)
                        decoded_parts.append(decoded)
                    except Exception as token_error:
                        failed_tokens.append((i, token_id, str(token_error)))
                        decoded_parts.append(f"<TOKEN_ERROR_{token_id}>")
                if failed_tokens:
                    print(
                        f"Warning: {len(failed_tokens)} tokens failed to decode individually"
                    )
                    print(f"First few failed tokens: {failed_tokens[:5]}")
                assistant_reply = "".join(decoded_parts)
            except Exception as e3:
                print(f"Error with individual token decoding: {e3}")
                assistant_reply = f"<DECODING_ERROR: {str(e3)}>"
    num_input_tokens = input_length
    num_generated_tokens = (
        len(new_tokens_list) if hasattr(new_tokens_list, "__len__") else 0
    )
    duration = end_time - start_time
    tokens_per_second = num_generated_tokens / duration if duration > 0 else 0
    if PRINT_TRANSFORMERS_USER_PROMPT:
        print(f"\nDecoded output length: {len(assistant_reply)} characters")
        print(f"First 200 chars of output: {assistant_reply[:200]}")
        print("\n--- Performance ---")
        print(f"Time taken: {duration:.2f} seconds")
        print(f"Generated tokens: {num_generated_tokens}")
        print(f"Tokens per second: {tokens_per_second:.2f}")
    return assistant_reply, num_input_tokens, num_generated_tokens
# Function to send a request and update history
def send_request(
    prompt: str,
    conversation_history: List[dict],
    client: ai.Client | OpenAI,
    config: types.GenerateContentConfig,
    model_choice: str,
    system_prompt: str,
    temperature: float,
    bedrock_runtime: boto3.Session.client,
    model_source: str,
    local_model=_pii_model,
    tokenizer=_pii_tokenizer,
    assistant_model=_pii_assistant_model,
    assistant_prefill="",
    progress=Progress(track_tqdm=True),
    api_url: str = None,
) -> Tuple[object, List[dict], str, int, int]:
    """Sends a request to a language model and manages the conversation history.

    This function constructs the full prompt by appending the new user prompt to the conversation history,
    generates a response from the model, and updates the conversation history with the new prompt and response.
    It handles different model sources (Gemini, AWS, Azure/OpenAI, Local, inference-server) and includes retry
    logic for API calls.

    Args:
        prompt (str): The user's input prompt to be sent to the model.
        conversation_history (List[dict]): A list of dictionaries representing the ongoing conversation.
                                           Each dictionary should have 'role' and 'parts' keys.
        client (ai.Client | OpenAI): The API client object for the chosen model (e.g., Gemini `ai.Client`, or Azure/OpenAI `OpenAI`).
        config (types.GenerateContentConfig): Configuration settings for content generation (e.g., Gemini `types.GenerateContentConfig`).
        model_choice (str): The specific model identifier to use (e.g., "gemini-pro", "claude-v2").
        system_prompt (str): An optional system-level instruction or context for the model.
        temperature (float): Controls the randomness of the model's output, with higher values leading to more diverse responses.
        bedrock_runtime (boto3.Session.client): The boto3 Bedrock runtime client object for AWS models.
        model_source (str): Indicates the source/provider of the model (e.g., "Gemini", "AWS", "Local", "inference-server").
        local_model (object, optional): The local model object (if `model_source` is "Local"). Defaults to the module-level PII model.
        tokenizer (object, optional): The tokenizer object for local models. Defaults to the module-level PII tokenizer.
        assistant_model (object, optional): An optional assistant model used for speculative decoding with local models.
        assistant_prefill (str, optional): A string to pre-fill the assistant's response, useful for certain models like Claude. Defaults to "".
        progress (Progress, optional): A progress object for tracking the operation, typically from `tqdm`. Defaults to Progress(track_tqdm=True).
        api_url (str, optional): The API URL for inference-server calls. Required when model_source is 'inference-server'.

    Returns:
        Tuple[object, List[dict], str, int, int]: The raw model response, the updated
        conversation history, the cleaned response text, the number of input tokens,
        and the number of generated tokens (the token counts are only populated for
        local transformers calls; they are 0 otherwise).
    """
    # Constructing the full prompt from the conversation history
    full_prompt = "Conversation history:\n"
    num_transformer_input_tokens = 0
    num_transformer_generated_tokens = 0
    response_text = ""

    if not model_choice or model_choice == "":
        model_choice = None

    def _failed_return():
        # Shared failure payload returned once every retry attempt has been exhausted.
        return (
            ResponseObject(text="", usage_metadata={"RequestId": "FAILED"}),
            conversation_history,
            response_text,
            num_transformer_input_tokens,
            num_transformer_generated_tokens,
        )

    for entry in conversation_history:
        role = entry[
            "role"
        ].capitalize()  # Assuming the history is stored with 'role' and 'parts'
        message = " ".join(entry["parts"])  # Combining all parts of the message
        full_prompt += f"{role}: {message}\n"

    # Adding the new user prompt
    full_prompt += f"\nUser: {prompt}"

    # Clear any existing progress bars
    tqdm._instances.clear()

    progress_bar = range(0, number_of_api_retry_attempts)

    # NOTE: `range(0, n)` yields 0..n-1, so the final attempt is `i == n - 1`.
    # The previous code compared `i == n`, which could never be true; on total
    # failure execution then fell through to the post-processing below and raised
    # NameError on the never-assigned `response`. All branches below use the
    # corrected comparison.

    # Generate the model's response
    if "Gemini" in model_source:
        for i in progress_bar:
            try:
                print("Calling Gemini model, attempt", i + 1)

                response = client.models.generate_content(
                    model=model_choice, contents=full_prompt, config=config
                )
                break
            except Exception as e:
                # If fails, try again after X seconds in case there is a throttle limit
                print(
                    "Call to Gemini model failed:",
                    e,
                    " Waiting for ",
                    str(timeout_wait),
                    "seconds and trying again.",
                )

                time.sleep(timeout_wait)

                if i == number_of_api_retry_attempts - 1:
                    return _failed_return()
    elif "AWS" in model_source:
        for i in progress_bar:
            try:
                response = call_aws_bedrock(
                    prompt,
                    system_prompt,
                    temperature,
                    max_tokens,
                    model_choice,
                    bedrock_runtime=bedrock_runtime,
                    assistant_prefill=assistant_prefill,
                )
                break
            except Exception as e:
                # If fails, try again after X seconds in case there is a throttle limit
                print(
                    "Call to Bedrock model failed:",
                    e,
                    " Waiting for ",
                    str(timeout_wait),
                    "seconds and trying again.",
                )

                time.sleep(timeout_wait)

                if i == number_of_api_retry_attempts - 1:
                    return _failed_return()
    elif "Azure/OpenAI" in model_source:
        for i in progress_bar:
            try:
                print("Calling Azure/OpenAI inference model, attempt", i + 1)

                messages = [
                    {
                        "role": "system",
                        "content": system_prompt,
                    },
                    {
                        "role": "user",
                        "content": prompt,
                    },
                ]

                response_raw = client.chat.completions.create(
                    messages=messages,
                    model=model_choice,
                    temperature=temperature,
                    max_completion_tokens=max_tokens,
                )

                response_text = response_raw.choices[0].message.content

                # Different OpenAI-compatible servers name the usage fields
                # differently; try both conventions.
                usage = getattr(response_raw, "usage", None)
                input_tokens = 0
                output_tokens = 0
                if usage is not None:
                    input_tokens = getattr(
                        usage, "input_tokens", getattr(usage, "prompt_tokens", 0)
                    )
                    output_tokens = getattr(
                        usage, "output_tokens", getattr(usage, "completion_tokens", 0)
                    )

                response = ResponseObject(
                    text=response_text,
                    usage_metadata={
                        "inputTokens": input_tokens,
                        "outputTokens": output_tokens,
                    },
                )

                break
            except Exception as e:
                print(
                    "Call to Azure/OpenAI model failed:",
                    e,
                    " Waiting for ",
                    str(timeout_wait),
                    "seconds and trying again.",
                )

                time.sleep(timeout_wait)

                if i == number_of_api_retry_attempts - 1:
                    return _failed_return()
    elif "Local" in model_source:
        # This is the local model. When USE_TRANSFORMERS_VLM_MODEL_AS_LLM and model_choice is the VLM model, use the loaded VLM model/tokenizer.
        vlm_model, vlm_tokenizer = None, None
        if (
            USE_TRANSFORMERS_VLM_MODEL_AS_LLM
            and model_choice == SELECTED_LOCAL_TRANSFORMERS_VLM_MODEL
        ):
            try:
                from tools.run_vlm import get_loaded_vlm_model_and_tokenizer

                vlm_model, vlm_tokenizer = get_loaded_vlm_model_and_tokenizer()
            except Exception as e:
                print(
                    f"Could not get VLM model for LLM task (USE_TRANSFORMERS_VLM_MODEL_AS_LLM): {e}"
                )

        for i in progress_bar:
            try:
                print("Calling local model, attempt", i + 1)

                gen_config = LocalLLMGenerationConfig()
                gen_config.update_temp(temperature)

                # Call transformers model; use VLM model/tokenizer when USE_TRANSFORMERS_VLM_MODEL_AS_LLM and available
                if vlm_model is not None and vlm_tokenizer is not None:
                    (
                        response,
                        num_transformer_input_tokens,
                        num_transformer_generated_tokens,
                    ) = call_transformers_model(
                        prompt,
                        system_prompt,
                        gen_config,
                        model=vlm_model,
                        tokenizer=vlm_tokenizer,
                        use_vlm_safe_generation=VLM_DEFAULT_DO_SAMPLE,
                    )
                else:
                    (
                        response,
                        num_transformer_input_tokens,
                        num_transformer_generated_tokens,
                    ) = call_transformers_model(
                        prompt,
                        system_prompt,
                        gen_config,
                    )

                response_text = response

                break
            except Exception as e:
                # If fails, try again after X seconds in case there is a throttle limit
                print(
                    "Call to local model failed:",
                    e,
                    " Waiting for ",
                    str(timeout_wait),
                    "seconds and trying again.",
                )

                time.sleep(timeout_wait)

                if i == number_of_api_retry_attempts - 1:
                    return _failed_return()
    elif "inference-server" in model_source:
        # This is the inference-server API
        for i in progress_bar:
            try:
                print("Calling inference-server API, attempt", i + 1)

                if api_url is None:
                    raise ValueError(
                        "api_url is required when model_source is 'inference-server'"
                    )

                gen_config = LocalLLMGenerationConfig()
                gen_config.update_temp(temperature)

                response = call_inference_server_api(
                    prompt,
                    system_prompt,
                    gen_config,
                    api_url=api_url,
                    model_name=model_choice,
                    use_llama_swap=USE_LLAMA_SWAP,
                )

                break
            except Exception as e:
                # If fails, try again after X seconds in case there is a throttle limit
                print(
                    "Call to inference-server API failed:",
                    e,
                    " Waiting for ",
                    str(timeout_wait),
                    "seconds and trying again.",
                )

                time.sleep(timeout_wait)

                if i == number_of_api_retry_attempts - 1:
                    return _failed_return()
    else:
        print("Model source not recognised")
        return _failed_return()

    # Update the conversation history with the new prompt and response
    conversation_history.append({"role": "user", "parts": [prompt]})

    # Guard against model_choice having been normalised to None above: calling
    # .lower() on None would raise AttributeError.
    model_choice_lower = model_choice.lower() if model_choice else ""
    # GPT-OSS "thinking" models emit reasoning tokens before a final-channel marker
    # (handle both hyphen and underscore spellings, case-insensitively).
    is_gpt_oss = "gpt-oss" in model_choice_lower or "gpt_oss" in model_choice_lower

    # Check if is a LLama.cpp model response or inference-server response
    if isinstance(response, ResponseObject):
        response_text = response.text
    elif "choices" in response:  # LLama.cpp model response or inference-server response
        if is_gpt_oss:
            content = _stringify_openai_message_content(
                response["choices"][0]["message"].get("content")
            )
            # Split on the final channel marker to extract only the final output (not thinking tokens).
            # str.split always returns at least one element, so the two cases below
            # are exhaustive (the original code's trailing `else` was unreachable).
            parts = content.split("<|start|>assistant<|channel|>final<|message|>")
            if len(parts) > 1:
                response_text = parts[1]
            else:
                # Following format may be from llama.cpp inference-server response
                parts = content.split("<|end|>")
                if len(parts) > 1:
                    response_text = parts[1]
                else:
                    print(
                        "Warning: Could not find final channel marker in GPT-OSS response. Using full content."
                    )
                    response_text = content
        else:
            response_text = _extract_choice_message_text(response["choices"][0])
    elif model_source == "Gemini":
        response_text = response.text
    else:  # Assume transformers model response (a plain string)
        if is_gpt_oss:
            # Split on the final channel marker to extract only the final output (not thinking tokens)
            parts = response.split("<|start|>assistant<|channel|>final<|message|>")
            if len(parts) > 1:
                response_text = parts[1]
            else:
                # Fallback: if marker not found, use the full content (may include thinking tokens)
                print(
                    "Warning: Could not find final channel marker in GPT-OSS response. Using full content."
                )
                response_text = response
        else:
            response_text = response

    # Strip <|end|> tags (used by GPT-OSS thinking models to mark end of thinking)
    response_text = response_text or ""
    response_text = re.sub(r"<\|end\|>", "", response_text)
    # Replace multiple spaces with single space
    response_text = re.sub(r" {2,}", " ", response_text)
    response_text = response_text.strip()

    conversation_history.append({"role": "assistant", "parts": [response_text]})

    return (
        response,
        conversation_history,
        response_text,
        num_transformer_input_tokens,
        num_transformer_generated_tokens,
    )
def process_requests(
    prompts: List[str],
    system_prompt: str,
    conversation_history: List[dict],
    whole_conversation: List[str],
    whole_conversation_metadata: List[str],
    client: ai.Client | OpenAI,
    config: types.GenerateContentConfig,
    model_choice: str,
    temperature: float,
    bedrock_runtime: boto3.Session.client,
    model_source: str,
    batch_no: int = 1,
    local_model=_pii_model,
    tokenizer=_pii_tokenizer,
    assistant_model=_pii_assistant_model,
    master: bool = False,
    assistant_prefill="",
    api_url: str = None,
) -> Tuple[List[ResponseObject], List[dict], List[str], List[str], str]:
    """
    Processes a list of prompts by sending them to the model, appending the responses to the conversation history, and updating the whole conversation and metadata.

    Args:
        prompts (List[str]): A list of prompts to be processed.
        system_prompt (str): The system prompt.
        conversation_history (List[dict]): The history of the conversation.
        whole_conversation (List[str]): The complete conversation including prompts and responses.
        whole_conversation_metadata (List[str]): Metadata about the whole conversation.
        client (object): The client to use for processing the prompts, from either Gemini or OpenAI client.
        config (dict): Configuration for the model.
        model_choice (str): The choice of model to use.
        temperature (float): The temperature parameter for the model.
        bedrock_runtime: The client object for boto3 Bedrock runtime.
        model_source (str): Source of the model, whether local, AWS, Gemini, or inference-server.
        batch_no (int): Batch number of the large language model request.
        local_model: Local model (if loaded).
        tokenizer: Tokenizer for the local model (if loaded).
        assistant_model: Optional assistant model for speculative decoding with local models.
        master (bool): Is this request for the master table.
        assistant_prefill (str, optional): Is there a prefill for the assistant response. Currently only working for AWS model calls.
        api_url (str, optional): The API URL for inference-server calls. Required when model_source is 'inference-server'.

    Returns:
        Tuple[List[ResponseObject], List[dict], List[str], List[str], str]: A tuple containing
        the list of responses, the updated conversation history, the updated whole conversation,
        the updated whole conversation metadata, and the last response text (empty string if no
        prompts were processed).
    """
    responses = list()
    # FIX: initialise response_text so the final return does not raise NameError
    # when `prompts` is empty.
    response_text = ""

    # Clear any existing progress bars
    tqdm._instances.clear()

    for prompt in prompts:

        (
            response,
            conversation_history,
            response_text,
            num_transformer_input_tokens,
            num_transformer_generated_tokens,
        ) = send_request(
            prompt,
            conversation_history,
            client=client,
            config=config,
            model_choice=model_choice,
            system_prompt=system_prompt,
            temperature=temperature,
            local_model=local_model,
            tokenizer=tokenizer,
            assistant_model=assistant_model,
            assistant_prefill=assistant_prefill,
            bedrock_runtime=bedrock_runtime,
            model_source=model_source,
            api_url=api_url,
        )

        responses.append(response)
        whole_conversation.append(system_prompt)
        whole_conversation.append(prompt)
        whole_conversation.append(response_text)

        # Create conversation metadata
        whole_conversation_metadata.append(f"Batch {batch_no}:")

        try:
            # Each model source reports token usage in a different place/format.
            if "AWS" in model_source:
                output_tokens = response.usage_metadata.get("outputTokens", 0)
                input_tokens = response.usage_metadata.get("inputTokens", 0)
            elif "Gemini" in model_source:
                output_tokens = response.usage_metadata.candidates_token_count
                input_tokens = response.usage_metadata.prompt_token_count
            elif "Azure/OpenAI" in model_source:
                input_tokens = response.usage_metadata.get("inputTokens", 0)
                output_tokens = response.usage_metadata.get("outputTokens", 0)
            elif "Local" in model_source:
                input_tokens = num_transformer_input_tokens
                output_tokens = num_transformer_generated_tokens
            elif "inference-server" in model_source:
                # inference-server returns the same format as llama-cpp
                output_tokens = response["usage"].get("completion_tokens", 0)
                input_tokens = response["usage"].get("prompt_tokens", 0)
            else:
                input_tokens = 0
                output_tokens = 0

            whole_conversation_metadata.append(
                "input_tokens: "
                + str(input_tokens)
                + " output_tokens: "
                + str(output_tokens)
            )

        except KeyError as e:
            print(f"Key error: {e} - Check the structure of response.usage_metadata")

    return (
        responses,
        conversation_history,
        whole_conversation,
        whole_conversation_metadata,
        response_text,
    )
def call_inference_server_api(
    formatted_string: str,
    system_prompt: str,
    gen_config: LocalLLMGenerationConfig,
    api_url: str = "http://localhost:8080",
    model_name: str = None,
    use_llama_swap: bool = USE_LLAMA_SWAP,
):
    """
    Calls a inference-server API endpoint with a formatted user message and system prompt,
    using generation parameters from the LocalLLMGenerationConfig object.

    This function provides the same interface as call_transformers_model but calls
    a remote inference-server instance instead of a local model.

    Args:
        formatted_string (str): The formatted input text for the user's message.
        system_prompt (str): The system-level instructions for the model.
        gen_config (LocalLLMGenerationConfig): An object containing generation parameters.
        api_url (str): The base URL of the inference-server API (default: "http://localhost:8080").
        model_name (str): Optional model name to use. If None, uses the default model.
        use_llama_swap (bool): Whether to use llama-swap for the model.

    Returns:
        dict: Response in the same format as the inference-server chat completions API

    Raises:
        ConnectionError: If the server cannot be reached.
        ValueError: If the server returns invalid JSON or an unexpected payload shape.
        RuntimeError: For any other failure while calling the API.

    Example:
        # Create generation config
        gen_config = LocalLLMGenerationConfig(temperature=0.7, max_tokens=100)

        # Call the API
        response = call_inference_server_api(
            formatted_string="Hello, how are you?",
            system_prompt="You are a helpful assistant.",
            gen_config=gen_config,
            api_url="http://localhost:8080"
        )

        # Extract the response text
        response_text = response['choices'][0]['message']['content']
    """
    # Extract parameters from the gen_config object
    temperature = gen_config.temperature
    top_k = gen_config.top_k
    top_p = gen_config.top_p
    repeat_penalty = gen_config.repeat_penalty
    seed = gen_config.seed
    max_tokens = gen_config.max_tokens
    stream = gen_config.stream

    # Prepare the request payload
    messages = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": formatted_string},
    ]

    payload = {
        "messages": messages,
        "temperature": temperature,
        "top_k": top_k,
        "top_p": top_p,
        "repeat_penalty": repeat_penalty,
        "seed": seed,
        "max_tokens": max_tokens,
        "stream": stream,
        "stop": LLM_STOP_STRINGS if LLM_STOP_STRINGS else [],
    }

    # Include model in payload when set (vLLM/OpenAI-compatible servers; llama-swap or not).
    # FIX: the previous check `model_name or model_name != ""` was always True when
    # model_name was None (because None != "" is True), which sent "model": null.
    if model_name:
        payload["model"] = model_name

    # Match VLM path: Qwen3 / Qwen3.5 on vLLM may stream only "thinking" unless disabled.
    if INFERENCE_SERVER_DISABLE_THINKING:
        payload["chat_template_kwargs"] = {"enable_thinking": False}

    # Streaming and non-streaming requests share the same OpenAI-compatible endpoint;
    # the server distinguishes them via the "stream" flag in the payload.
    endpoint = f"{api_url}/v1/chat/completions"

    try:
        if stream:
            # Handle streaming response (SSE: lines prefixed with "data: ")
            response = requests.post(
                endpoint,
                json=payload,
                headers={"Content-Type": "application/json"},
                stream=True,
                timeout=timeout_wait,
            )
            response.raise_for_status()

            final_tokens = []
            output_tokens = 0
            line_buffer = ""

            for line in response.iter_lines():
                if line:
                    line = line.decode("utf-8")
                    if line.startswith("data: "):
                        data = line[6:]  # Remove 'data: ' prefix
                        if data.strip() == "[DONE]":
                            if REPORT_LLM_OUTPUTS_TO_GUI and line_buffer.strip():
                                _report_llm_output_to_gui(line_buffer)
                            break
                        try:
                            chunk = json.loads(data)
                            if "choices" in chunk and len(chunk["choices"]) > 0:
                                delta = chunk["choices"][0].get("delta", {})
                                token = delta.get("content")
                                token = _stringify_openai_message_content(token)
                                if not token:
                                    # Some servers stream reasoning text under
                                    # alternative delta keys; fall back to those.
                                    for alt in (
                                        "reasoning_content",
                                        "reasoning",
                                    ):
                                        t = delta.get(alt)
                                        if isinstance(t, str) and t:
                                            token = t
                                            break
                                if token:
                                    print(token, end="", flush=True)
                                    final_tokens.append(token)
                                    # NOTE: counts streamed chunks, not true model
                                    # tokens — an approximation for usage reporting.
                                    output_tokens += 1

                                    if REPORT_LLM_OUTPUTS_TO_GUI:
                                        line_buffer += token
                                        if "\n" in token:
                                            parts = line_buffer.split("\n")
                                            for complete_line in parts[:-1]:
                                                if complete_line.strip():
                                                    _report_llm_output_to_gui(
                                                        complete_line
                                                    )
                                            line_buffer = parts[-1] if parts else ""
                        except json.JSONDecodeError:
                            continue

            if REPORT_LLM_OUTPUTS_TO_GUI and line_buffer.strip():
                _report_llm_output_to_gui(line_buffer)

            print()  # newline after stream finishes

            text = "".join(final_tokens)

            # Estimate input tokens (rough approximation via whitespace split)
            input_tokens = len((system_prompt + "\n" + formatted_string).split())

            return {
                "choices": [
                    {
                        "index": 0,
                        "finish_reason": "stop",
                        "message": {"role": "assistant", "content": text},
                    }
                ],
                "usage": {
                    "prompt_tokens": input_tokens,
                    "completion_tokens": output_tokens,
                    "total_tokens": input_tokens + output_tokens,
                },
            }
        else:
            # Handle non-streaming response
            response = requests.post(
                endpoint,
                json=payload,
                headers={"Content-Type": "application/json"},
                timeout=timeout_wait,
            )
            response.raise_for_status()

            result = response.json()

            # Ensure the response has the expected format
            if "choices" not in result:
                raise ValueError("Invalid response format from inference-server")

            return result

    except requests.exceptions.RequestException as e:
        raise ConnectionError(
            f"Failed to connect to inference-server at {api_url}: {str(e)}"
        )
    except json.JSONDecodeError as e:
        raise ValueError(f"Invalid JSON response from inference-server: {str(e)}")
    except Exception as e:
        raise RuntimeError(f"Error calling inference-server API: {str(e)}")
###
# LLM FUNCTIONS
###
def construct_gemini_generative_model(
    in_api_key: str,
    temperature: float,
    model_choice: str,
    system_prompt: str,
    max_tokens: int,
    random_seed=seed,
) -> Tuple[object, dict]:
    """
    Construct a Gemini API client and generation config.

    Args:
        in_api_key (str): API key supplied by the caller; when empty, falls back
            to the GOOGLE_API_KEY environment variable.
        temperature (float): Sampling temperature for content generation.
        model_choice (str): Model identifier (not used here; supplied at call time).
        system_prompt (str): System prompt (not used here; supplied at call time).
        max_tokens (int): Maximum number of output tokens for generation.
        random_seed (int, optional): Seed for reproducible generation. Defaults to
            the module-level `seed`.

    Returns:
        Tuple[object, dict]: The Gemini `ai.Client` and a `GenerateContentConfig`.

    Raises:
        Warning: If no API key is available, or if the client cannot be constructed.
    """
    # Resolve the API key first, outside the construction try/except.
    # FIX: the missing-key Warning was previously raised inside the same try block
    # that caught it, so the specific "No Gemini API key found." message was
    # swallowed and re-wrapped as a generic construction error.
    if in_api_key:
        api_key = in_api_key
    elif "GOOGLE_API_KEY" in os.environ:
        api_key = os.environ["GOOGLE_API_KEY"]
    else:
        print("No Gemini API key found")
        raise Warning("No Gemini API key found.")

    # Construct a GenerativeModel
    try:
        client = ai.Client(api_key=api_key)
    except Exception as e:
        print("Error constructing Gemini generative model:", e)
        raise Warning("Error constructing Gemini generative model:", e)

    config = types.GenerateContentConfig(
        temperature=temperature, max_output_tokens=max_tokens, seed=random_seed
    )

    return client, config
def construct_azure_client(in_api_key: str, endpoint: str) -> Tuple[object, dict]:
    """
    Build an OpenAI-compatible client for Azure/OpenAI AI Inference.

    The API key is taken from `in_api_key` when provided, otherwise from the
    AZURE_OPENAI_API_KEY environment variable. When no endpoint is given (either
    as an argument or via AZURE_OPENAI_INFERENCE_ENDPOINT), the plain OpenAI API
    is assumed.

    Args:
        in_api_key (str): API key supplied by the caller (may be empty).
        endpoint (str): Base URL of the inference endpoint (may be empty).

    Returns:
        Tuple[object, dict]: The constructed OpenAI client and an empty config dict.

    Raises:
        Warning: If no API key can be found.
    """
    try:
        # Prefer the explicitly supplied key; fall back to the environment.
        api_key = in_api_key or os.environ.get("AZURE_OPENAI_API_KEY")

        if not api_key:
            raise Warning("No Azure/OpenAI API key found.")

        if not endpoint:
            endpoint = os.environ.get("AZURE_OPENAI_INFERENCE_ENDPOINT", "")

        if endpoint:
            # Use the provided (or environment-derived) endpoint
            azure_client = OpenAI(
                api_key=api_key,
                base_url=f"{endpoint}",
            )
        else:
            # Assume using OpenAI API
            azure_client = OpenAI(
                api_key=api_key,
            )

        return azure_client, dict()
    except Exception as e:
        print("Error constructing Azure/OpenAI client:", e)
        raise
def call_aws_bedrock(
    prompt: str,
    system_prompt: str,
    temperature: float,
    max_tokens: int,
    model_choice: str,
    bedrock_runtime: boto3.Session.client,
    assistant_prefill: str = "",
    max_retries: int = 5,
    retry_delay_seconds: float = 2.0,
) -> ResponseObject:
    """
    Send a single conversation turn to AWS Bedrock via the `converse` API.

    Args:
        prompt (str): The user's input prompt to be processed by the model.
        system_prompt (str): A system-defined prompt that provides context or instructions for the model.
        temperature (float): Controls the randomness of the model's output.
        max_tokens (int): The maximum number of tokens in the model's response.
        model_choice (str): The specific model to use for processing the request.
        bedrock_runtime (boto3.Session.client): The client object for boto3 Bedrock runtime.
        assistant_prefill (str): Text the model's response should start with
            (only applied for Anthropic models).
        max_retries (int): Maximum number of retry attempts on failure (default 5).
        retry_delay_seconds (float): Delay in seconds between retries (default 2.0).

    Returns:
        ResponseObject: The response text (with any prefill re-attached) and the
        usage metadata from the API response.

    Raises:
        RuntimeError: If all retry attempts fail; chained from the last underlying error.
    """
    inference_config = {
        "maxTokens": max_tokens,
        "temperature": temperature,
    }

    # Using an assistant prefill only works for Anthropic models.
    if assistant_prefill and "anthropic" in model_choice:
        assistant_prefill_added = True
        messages = [
            {
                "role": "user",
                "content": [
                    {"text": prompt},
                ],
            },
            {
                "role": "assistant",
                # Pre-fill the start of the assistant's reply; the prefill text is
                # re-attached to the returned text below so callers see the full reply.
                "content": [{"text": assistant_prefill}],
            },
        ]
    else:
        assistant_prefill_added = False
        messages = [
            {
                "role": "user",
                "content": [
                    {"text": prompt},
                ],
            }
        ]

    system_prompt_list = [{"text": system_prompt}]

    last_error = None
    for attempt in range(1, max_retries + 1):
        try:
            # The converse API call.
            api_response = bedrock_runtime.converse(
                modelId=model_choice,
                messages=messages,
                system=system_prompt_list,
                inferenceConfig=inference_config,
            )

            output_message = api_response["output"]["message"]

            # Reasoning-capable models return a reasoningContent block first and the
            # answer text second; otherwise the answer text is the first content item.
            # FIX: removed two bare expression statements whose results were discarded
            # (the reasoning-text lookup and api_response["ResponseMetadata"]) — dead
            # code that could only add spurious KeyErrors.
            if "reasoningContent" in output_message["content"][0]:
                text_index = 1
            else:
                text_index = 0

            # Extract the output text, re-attaching the prefill if one was sent.
            text = output_message["content"][text_index]["text"]
            if assistant_prefill_added:
                text = assistant_prefill + text

            # The usage statistics are neatly provided in the 'usage' key.
            usage = api_response["usage"]

            # Create ResponseObject with the cleanly extracted data.
            return ResponseObject(text=text, usage_metadata=usage)
        except Exception as e:
            last_error = e
            if attempt < max_retries:
                print(
                    f"Bedrock converse API attempt {attempt}/{max_retries} failed: {e}. "
                    f"Retrying in {retry_delay_seconds}s..."
                )
                time.sleep(retry_delay_seconds)
            else:
                raise RuntimeError(
                    f"Failed to call Bedrock API after {max_retries} attempts: {str(last_error)}"
                ) from last_error
def calculate_tokens_from_metadata(
    metadata_string: str, model_choice: str, model_name_map: dict
):
    """
    Sum the input and output token counts recorded in a metadata string.

    Every "input_tokens: N" / "output_tokens: N" entry in the string is found
    and the values are totalled; the number of "input_tokens" entries is
    treated as the number of LLM calls.

    Args:
        metadata_string (str): A string containing all relevant metadata entries.
        model_choice (str): A string describing the model name.
        model_name_map (dict): A dictionary mapping model name to source.

    Returns:
        Tuple of (total_input_tokens, total_output_tokens, number_of_calls).
    """
    # Capture the digits following each token-count key anywhere in the string.
    input_counts = [
        int(match) for match in re.findall(r"input_tokens: (\d+)", metadata_string)
    ]
    output_counts = [
        int(match) for match in re.findall(r"output_tokens: (\d+)", metadata_string)
    ]

    total_input_tokens = sum(input_counts)
    total_output_tokens = sum(output_counts)
    number_of_calls = len(input_counts)

    print(f"Found {number_of_calls} LLM call entries in metadata.")
    print("-" * 20)
    print(f"Total Input Tokens: {total_input_tokens}")
    print(f"Total Output Tokens: {total_output_tokens}")

    return total_input_tokens, total_output_tokens, number_of_calls