import os
import re
import gc
import time
from datetime import datetime
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
from huggingface_hub import InferenceClient
from src.config import SYSTEM_PROMPT, MODEL_CONFIGS
from src.tools import web_search, scrape_url, format_search_results_for_prompt
# Conditional Zero-GPU Spaces import: `spaces` is an optional dependency, so
# fall back to a no-op decorator when it cannot be imported.
try:
    import spaces
    HAS_SPACES = True
    gpu_decorator = spaces.GPU
except ImportError:
    HAS_SPACES = False

    def gpu_decorator(f=None, **_options):
        """No-op stand-in for `spaces.GPU` when the `spaces` package is missing.

        Works in both decorator forms so call sites do not need to care
        whether Zero-GPU is available:

        * ``gpu_decorator(func)`` returns ``func`` unchanged.
        * ``gpu_decorator(duration=60)`` returns a decorator that returns the
          function unchanged; the options are accepted and ignored.
        """
        if f is None:
            # Parameterised form: hand back an identity decorator.
            return lambda func: func
        return f
# Module-level cache slots: the single model kept in memory, its tokenizer,
# and the repo id they were loaded for. All start out empty.
_current_model = _current_tokenizer = _current_repo_id = None
def unload_model():
    """Release the cached model and tokenizer so their RAM/GPU memory can be reclaimed.

    Does nothing when no model is cached. Otherwise the three module-level
    cache slots are reset to ``None``, a garbage collection pass is forced,
    the CUDA allocator cache is emptied when CUDA is available, and the call
    sleeps for one second before returning.
    """
    global _current_model, _current_tokenizer, _current_repo_id

    # Guard clause: an empty cache needs no cleanup.
    if _current_model is None:
        return

    print(f"Unloading model: {_current_repo_id} to free memory...")

    # Drop the references held by the cache, then re-create the empty slots.
    del _current_model, _current_tokenizer
    _current_model = _current_tokenizer = _current_repo_id = None

    # Collect the now-unreferenced objects before asking CUDA to release
    # its cached blocks.
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    time.sleep(1)
def get_local_model(repo_id: str, hf_token: str = None):
    """Return ``(model, tokenizer)`` for ``repo_id``, loading them if needed.

    Only one model is kept in memory at a time. If ``repo_id`` matches the
    cached entry it is returned directly; otherwise the cached model is
    unloaded first and the requested one is loaded with ``from_pretrained``.

    Args:
        repo_id: Hugging Face repository id of the model.
        hf_token: Access token; falls back to the ``HF_TOKEN`` environment
            variable when empty or ``None``.

    Returns:
        A ``(model, tokenizer)`` tuple.

    Raises:
        ValueError: If the tokenizer or model cannot be loaded. The message
            carries access instructions when the failure looks like a
            gated-repo / authorization problem. The original exception is
            attached as ``__cause__``.
    """
    global _current_model, _current_tokenizer, _current_repo_id

    if _current_repo_id == repo_id and _current_model is not None:
        return _current_model, _current_tokenizer

    # Unload previous model to avoid out-of-memory errors
    unload_model()

    print(f"Loading model: {repo_id}...")
    token = hf_token or os.environ.get("HF_TOKEN")

    # Half precision with automatic placement on GPU; full precision on CPU.
    if torch.cuda.is_available():
        device_map, torch_dtype = "auto", torch.float16
    else:
        device_map, torch_dtype = "cpu", torch.float32

    try:
        tokenizer = AutoTokenizer.from_pretrained(repo_id, token=token)
        model = AutoModelForCausalLM.from_pretrained(
            repo_id,
            device_map=device_map,
            torch_dtype=torch_dtype,
            low_cpu_mem_usage=True,
            token=token
        )
    except Exception as e:
        # Classify the failure by message text so gated/unauthorized repos
        # get actionable instructions instead of a raw stack trace.
        lowered = str(e).lower()
        access_markers = ("gated repo", "401", "unauthorized", "gatedrepoerror")
        if any(marker in lowered for marker in access_markers):
            raise ValueError(
                f"❌ Hugging Face Access Error: The model you selected (`{repo_id}`) is a Gated Model.\n\n"
                f"To access this model, please:\n"
                f"1. Accept the licensing agreement on the model page: [huggingface.co/{repo_id}](https://huggingface.co/{repo_id})\n"
                f"2. Provide your Hugging Face API Read Token in the *Advanced Parameters* input in the sidebar, or set it as a Space Secret named `HF_TOKEN` in your Hugging Face Space settings."
            ) from e
        raise ValueError(f"❌ Failed to load model `{repo_id}`: {e}") from e

    # Publish to the cache only after both pieces loaded successfully.
    _current_model = model
    _current_tokenizer = tokenizer
    _current_repo_id = repo_id
    print(f"Successfully loaded {repo_id} into memory.")
    return model, tokenizer
def sample_next_token(logits, temperature: float, top_p: float):
    """Pick the next token id from ``logits``.

    A positive ``temperature`` rescales the logits (unless it is exactly 1.0)
    and the token is drawn with ``torch.multinomial``; otherwise the arg-max
    is taken. When ``0 < top_p < 1`` the distribution is first restricted to
    the nucleus: the highest-probability tokens up to and including the one
    that pushes the cumulative probability past ``top_p``.

    Returns a tensor with a trailing dimension of size 1 holding the chosen
    index.
    """
    sampling = temperature > 0.0

    # Temperature scaling (skipped for greedy decoding and the identity case).
    if sampling and temperature != 1.0:
        logits = logits / temperature

    # Top-p (nucleus) filtering.
    if 0.0 < top_p < 1.0:
        ranked_logits, ranked_idx = torch.sort(logits, descending=True)
        cumulative = torch.softmax(ranked_logits, dim=-1).cumsum(dim=-1)
        over_budget = cumulative > top_p
        # Shift the mask one rank to the right so the token that crosses the
        # threshold is kept, and the best-ranked token is never dropped.
        drop_ranked = torch.cat(
            [torch.zeros_like(over_budget[..., :1]), over_budget[..., :-1]],
            dim=-1,
        )
        # Map the rank-ordered mask back to vocabulary order.
        drop = torch.zeros_like(drop_ranked).scatter(-1, ranked_idx, drop_ranked)
        logits = logits.masked_fill(drop, -float("Inf"))

    if not sampling:
        return torch.argmax(logits, dim=-1, keepdim=True)
    return torch.multinomial(torch.softmax(logits, dim=-1), num_samples=1)
def generate_local_inference_cpu(prompt_text: str, repo_id: str, max_new_tokens: int, temperature: float, top_p: float, hf_token: str = None):
    """Stream a completion for ``prompt_text`` from a locally loaded model.

    Runs a hand-written decoding loop on the calling thread: the whole prompt
    is fed once, then one token per step is fed back together with the KV
    cache returned by the previous step.

    Args:
        prompt_text: Fully formatted prompt string.
        repo_id: Hugging Face repository id passed to ``get_local_model``.
        max_new_tokens: Upper bound on generated tokens; values <= 0 produce
            no output.
        temperature: Sampling temperature (<= 0 means greedy decoding).
        top_p: Nucleus sampling threshold.
        hf_token: Optional access token forwarded to ``get_local_model``.

    Yields:
        The decoded text generated so far (cumulative, not a delta) after
        each new token. Generation stops at the first token whose id is in
        the tokenizer's ``all_special_ids``.
    """
    model, tokenizer = get_local_model(repo_id, hf_token)
    device = next(model.parameters()).device

    # Tokenize prompt
    input_ids = tokenizer(prompt_text, return_tensors="pt").input_ids.to(device)

    # Any special token (<|endoftext|>, <|im_end|>, ...) ends generation.
    # A set keeps the per-token membership test O(1).
    stop_tokens = set(tokenizer.all_special_ids)

    past_key_values = None
    generated_ids = []
    # First step consumes the whole prompt; later steps only the new token.
    step_input = input_ids

    for _ in range(int(max_new_tokens)):
        # Grad mode is thread-local state, so the no_grad context is opened
        # per step and closed before yielding; holding it across a `yield`
        # would leave gradients disabled in the consumer's code.
        with torch.no_grad():
            outputs = model(step_input, past_key_values=past_key_values, use_cache=True)
            next_token = sample_next_token(outputs.logits[:, -1, :], temperature, top_p)
        past_key_values = outputs.past_key_values

        next_token_id = next_token.item()
        if next_token_id in stop_tokens:
            break

        generated_ids.append(next_token_id)
        # Decode the full sequence each time so multi-token characters and
        # whitespace merging come out right.
        yield tokenizer.decode(generated_ids, skip_special_tokens=True)
        step_input = next_token
# GPU entry point: the same generator function as generate_local_inference_cpu,
# passed through gpu_decorator (spaces.GPU when the `spaces` package imported,
# otherwise a pass-through that returns the function unchanged).
generate_local_inference_gpu = gpu_decorator(generate_local_inference_cpu)
def run_serverless_api_inference(messages: list, repo_id: str, max_new_tokens: int, temperature: float, top_p: float, hf_token: str = None):
    """Stream a chat completion from the Hugging Face serverless Inference API.

    Args:
        messages: Chat messages as ``{"role": ..., "content": ...}`` dicts.
        repo_id: Model repository id to query.
        max_new_tokens: Forwarded as ``max_tokens``.
        temperature: Sampling temperature.
        top_p: Nucleus sampling threshold.
        hf_token: Access token; falls back to the ``HF_TOKEN`` environment
            variable when empty or ``None``.

    Yields:
        The accumulated response text after each content-bearing chunk. If
        the request or the stream fails, a single error message string is
        yielded instead of raising.
    """
    # Retrieve token from environment variables if not provided explicitly
    token = hf_token or os.environ.get("HF_TOKEN")

    client = InferenceClient(model=repo_id, token=token)
    generated_text = ""
    try:
        response_stream = client.chat_completion(
            messages=messages,
            max_tokens=max_new_tokens,
            temperature=temperature,
            top_p=top_p,
            stream=True
        )
        for chunk in response_stream:
            # Some stream chunks carry no choices; indexing [0] on them would
            # raise and be reported as an API error.
            if not chunk.choices:
                continue
            content = chunk.choices[0].delta.content
            if content:
                generated_text += content
                yield generated_text
    except Exception as e:
        # Deliberate best-effort: surface the failure in the chat instead of
        # crashing the UI generator.
        error_msg = f"Serverless API Error: {str(e)}\n\n"
        if not token:
            error_msg += "💡 Tip: Many models require a valid Hugging Face Token for serverless inference. Please enter your HF Token in the sidebar panel."
        yield error_msg
def build_prompt_with_history(messages: list, system_prompt: str, tokenizer=None) -> str:
    """Render the conversation, preceded by a system message, as one prompt string.

    The tokenizer's own chat template is used when a tokenizer with
    ``apply_chat_template`` is supplied and the call succeeds. Otherwise the
    prompt is assembled by hand in ``<|im_start|>role ... <|im_end|>`` form,
    ending with an open assistant turn. Messages whose role is not system,
    user or assistant are left out of the hand-built prompt.
    """
    conversation = [{"role": "system", "content": system_prompt}] + messages

    if tokenizer is not None and hasattr(tokenizer, "apply_chat_template"):
        try:
            return tokenizer.apply_chat_template(conversation, tokenize=False, add_generation_prompt=True)
        except Exception:
            # Template unavailable or rejected the conversation: use the
            # generic layout below.
            pass

    segments = []
    for turn in conversation:
        speaker = turn["role"]
        body = turn["content"]
        tag = next((known for known in ("system", "user", "assistant") if speaker == known), None)
        if tag is not None:
            segments.append(f"<|im_start|>{tag}\n{body}<|im_end|>\n")

    # Leave the assistant turn open so the model continues from here.
    segments.append("<|im_start|>assistant\n")
    return "".join(segments)
# NOTE(review): the tag and HTML literals in this function were missing from
# the file (empty-string delimiters, unterminated f-strings). They are
# reconstructed here as <think>...</think> reasoning delimiters rendered as an
# HTML <details> panel - confirm against the tags the configured models emit
# and the markup the chat UI expects.
_THINK_OPEN = "<think>"
_THINK_CLOSE = "</think>"


def format_thinking_tags(text: str) -> str:
    """Render the model's first reasoning block as a collapsible HTML panel.

    Non-string input is coerced first: lists/tuples are space-joined,
    ``None`` becomes ``""``, anything else goes through ``str()``.

    * No opening tag: the text is returned unchanged.
    * Opening and closing tag: the reasoning is wrapped in a collapsed
      ``<details>`` panel titled "Thought Process", followed by the rest of
      the answer.
    * Opening tag only (still streaming): everything after the tag is shown
      in an open ``<details>`` panel titled "Thinking Process...".
    """
    if not isinstance(text, str):
        if isinstance(text, (list, tuple)):
            text = " ".join(str(x) for x in text)
        else:
            text = str(text) if text is not None else ""

    before_thinking, opener, rest = text.partition(_THINK_OPEN)
    if not opener:
        return text

    thinking_content, closer, after_thinking = rest.partition(_THINK_CLOSE)
    if closer:
        return (
            f"{before_thinking}<details><summary>Thought Process</summary>"
            f"\n\n{thinking_content.strip()}\n\n</details>\n\n{after_thinking}"
        )

    # Thinking block is still generating, render it open
    return (
        f"{before_thinking}<details open><summary>Thinking Process...</summary>"
        f"\n\n{rest.strip()}\n\n</details>"
    )
# NOTE(review): the tag literals of the original pattern were missing from the
# file (it had a single capture group while the code read four). The pattern
# below is reconstructed for <artifact title="..." type="..." language="...">
# blocks and reads the attributes in any order - confirm the tag name and
# attribute names against the system prompt that instructs the model.
_ARTIFACT_BLOCK_RE = re.compile(
    r'<artifact(?=[\s>])((?:[^>"]|"[^"]*")*)>(.*?)(?:</artifact>|$)',
    re.DOTALL,
)
_ARTIFACT_ATTR_RE = re.compile(r'([\w-]+)\s*=\s*"([^"]*)"')


def extract_artifacts(text: str) -> list:
    """Extract every ``<artifact>`` block from (possibly still streaming) text.

    Non-string input is coerced first: lists/tuples are space-joined,
    ``None`` becomes ``""``, anything else goes through ``str()``.

    A block whose closing tag has not arrived yet is captured up to the end
    of the text, so in-progress artifacts are included.

    Returns:
        A list of dicts with keys ``title``, ``type``, ``language`` and
        ``content`` (whitespace-stripped). Missing or empty attributes
        default to ``"Untitled"``, ``"code"`` and ``"plaintext"``.
    """
    if not isinstance(text, str):
        if isinstance(text, (list, tuple)):
            text = " ".join(str(x) for x in text)
        else:
            text = str(text) if text is not None else ""

    artifacts = []
    for m in _ARTIFACT_BLOCK_RE.finditer(text):
        # Group 1 is the raw attribute string of the opening tag, group 2
        # the body up to the closing tag or end of text.
        attrs = dict(_ARTIFACT_ATTR_RE.findall(m.group(1)))
        artifacts.append({
            "title": attrs.get("title") or "Untitled",
            "type": attrs.get("type") or "code",
            "language": attrs.get("language") or "plaintext",
            "content": m.group(2).strip(),
        })
    return artifacts
# NOTE(review): the tag literals of the original pattern were missing from the
# file (it had no capture groups while the callback read groups 1 and 2). The
# pattern below is reconstructed for <artifact title="..." type="..."> blocks
# and reads the attributes in any order - confirm the tag name and attribute
# names against the system prompt that instructs the model.
_BADGE_BLOCK_RE = re.compile(
    r'<artifact(?=[\s>])((?:[^>"]|"[^"]*")*)>.*?(?:</artifact>|$)',
    re.DOTALL,
)
_BADGE_ATTR_RE = re.compile(r'([\w-]+)\s*=\s*"([^"]*)"')


def clean_chatbot_response(text: str) -> str:
    """Replace each ``<artifact>`` block with a one-line badge.

    Non-string input is coerced first: lists/tuples are space-joined,
    ``None`` becomes ``""``, anything else goes through ``str()``.

    Every artifact block - including one that is still open at the end of
    the text - is swapped for a Markdown blockquote naming the artifact's
    title and type, so the raw artifact body does not appear in the chat.
    """
    if not isinstance(text, str):
        if isinstance(text, (list, tuple)):
            text = " ".join(str(x) for x in text)
        else:
            text = str(text) if text is not None else ""

    def replace_with_badge(match):
        # Group 1 is the raw attribute string of the opening tag.
        attrs = dict(_BADGE_ATTR_RE.findall(match.group(1)))
        title = attrs.get("title") or "Untitled Artifact"
        type_ = attrs.get("type") or "code"
        return f"\n\n> ⚙️ **Artifact Generated:** *{title}* ({type_}) — *Rendered in the right-hand panel* ↗️\n\n"

    return _BADGE_BLOCK_RE.sub(replace_with_badge, text)
# Start of the status block that execute_chat prepends to a reply when web
# search is enabled, and the endings of its final status line.
_SEARCH_STATUS_PREFIX = "🔍 Searching web"
_SEARCH_STATUS_END_MARKERS = (
    "Integrating context...\n",
    "Proceeding with model knowledge...\n",
)


def _lookup_repo_id(mode, model_name):
    """Return the repo id configured for ``model_name`` under ``mode``, or None."""
    for item in MODEL_CONFIGS.get(mode, []):
        if item["name"] == model_name:
            return item["repo_id"]
    return None


def _strip_search_status(bot_msg):
    """Return a stored bot reply without the web-search status lines in front of it."""
    if not bot_msg:
        # Unfinished turns may be stored as None; treat them as empty text.
        return ""
    if not bot_msg.startswith(_SEARCH_STATUS_PREFIX):
        return bot_msg
    for marker in _SEARCH_STATUS_END_MARKERS:
        _, found, answer = bot_msg.partition(marker)
        if found:
            return answer
    # Status block never completed: keep only the last line.
    return bot_msg.split("\n")[-1]


def _render_stream(stream, history, message, status_update):
    """Turn each partial model output into a ``(chat history, artifacts)`` UI update."""
    for partial_text in stream:
        formatted_text = format_thinking_tags(partial_text)
        artifacts = extract_artifacts(formatted_text)
        clean_text = clean_chatbot_response(formatted_text)
        yield history + [[message, status_update + clean_text]], artifacts


def execute_chat(
    message: str,
    history: list,
    mode: str,
    model_name: str,
    system_prompt_preset: str,
    max_new_tokens: int,
    temperature: float,
    top_p: float,
    enable_search: bool,
    hf_token: str
):
    """Run one chat turn and stream UI updates.

    Resolves the model for ``mode``/``model_name``, optionally performs a web
    search and scrapes the top result for context, converts ``history`` into
    chat messages, and streams the answer from the selected backend
    (serverless API, Zero-GPU, or local).

    Args:
        message: The new user message.
        history: Previous turns as ``[user_msg, bot_msg]`` pairs.
        mode: Backend mode; also the key into ``MODEL_CONFIGS``.
        model_name: Display name of the model within that mode.
        system_prompt_preset: System prompt; ``{datetime}`` is replaced with
            the current local time.
        max_new_tokens: Generation length limit.
        temperature: Sampling temperature.
        top_p: Nucleus sampling threshold.
        enable_search: Whether to run the web search step.
        hf_token: Hugging Face token; local modes fall back to ``HF_TOKEN``.

    Yields:
        ``(updated_history, artifacts)`` tuples. ``artifacts`` is ``""`` for
        status and configuration-error updates and a list once model output
        is streaming.
    """
    history = history or []

    # 1. Look up the repo_id from configs
    repo_id = _lookup_repo_id(mode, model_name)
    if not repo_id:
        yield history + [[message, "Configuration Error: Selected model details not found."]], ""
        return

    # 2. Handle web search if enabled
    search_context = ""
    status_update = ""
    if enable_search:
        status_update = f"🔍 Searching web for: '{message}'...\n"
        yield history + [[message, status_update]], ""

        results = web_search(message, max_results=3)
        if results:
            status_update += f"📄 Scraped {len(results)} relevant web sources. Integrating context...\n"
            yield history + [[message, status_update]], ""

            # Scrape details from the top result to enrich context
            top_url = results[0]["url"]
            scraped_content = scrape_url(top_url, max_chars=3000)

            # Format combined search results
            search_context = format_search_results_for_prompt(message, results)
            search_context += f"\nDetailed body scraped from source [1] ({top_url}):\n{scraped_content}\n---\n"
        else:
            status_update += "❌ Web search returned no results. Proceeding with model knowledge...\n"
            yield history + [[message, status_update]], ""
            # Presumably a short pause so the notice is readable before the
            # answer starts streaming.
            time.sleep(1)

    # 3. Compile history into chat messages. Search status lines shown in
    # earlier replies are removed so the model does not read them as its own
    # words.
    chat_messages = []
    for user_msg, bot_msg in history:
        chat_messages.append({"role": "user", "content": user_msg})
        chat_messages.append({"role": "assistant", "content": _strip_search_status(bot_msg)})

    # Prepare active prompt contents
    current_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
    compiled_system_prompt = (system_prompt_preset or "").replace("{datetime}", current_time)

    # Prepend search context to user query if found
    if search_context:
        user_query_content = f"{search_context}User Query: {message}"
    else:
        user_query_content = message
    chat_messages.append({"role": "user", "content": user_query_content})

    # 4. Invoke inference backend
    if mode == "HF Serverless API (Zero Overhead)":
        # The local path injects the system prompt while building the prompt
        # string; the API path has to send it as an explicit system message.
        api_messages = chat_messages
        if compiled_system_prompt:
            api_messages = [{"role": "system", "content": compiled_system_prompt}] + chat_messages

        api_stream = run_serverless_api_inference(
            messages=api_messages,
            repo_id=repo_id,
            max_new_tokens=max_new_tokens,
            temperature=temperature,
            top_p=top_p,
            hf_token=hf_token
        )
        yield from _render_stream(api_stream, history, message, status_update)
        return

    # Local CPU or Zero-GPU mode: load the tokenizer here only to apply its
    # chat template; build_prompt_with_history falls back to a generic layout
    # when it is unavailable.
    token = hf_token or os.environ.get("HF_TOKEN")
    try:
        tokenizer = AutoTokenizer.from_pretrained(repo_id, token=token)
    except Exception:
        tokenizer = None

    prompt_text = build_prompt_with_history(chat_messages, compiled_system_prompt, tokenizer)

    # Free up variables
    del tokenizer

    # Pick the GPU-decorated or plain generator based on the selected UI mode
    if mode == "Zero-GPU (Accelerated)":
        run_local = generate_local_inference_gpu
    else:
        run_local = generate_local_inference_cpu

    local_stream = run_local(
        prompt_text=prompt_text,
        repo_id=repo_id,
        max_new_tokens=max_new_tokens,
        temperature=temperature,
        top_p=top_p,
        hf_token=token
    )

    try:
        yield from _render_stream(local_stream, history, message, status_update)
    except Exception as e:
        # Model loading and generation errors surface lazily while iterating
        # the stream; report them in the chat instead of crashing the UI.
        error_message = f"\n\n### ⚠️ Inference Failure\n{e}"
        yield history + [[message, status_update + error_message]], []