# private-ai-backend / model.py
# Last change by adeebjamal in commit 6cf1dc4: "increased n_ctx"
import os
import logging
from huggingface_hub import hf_hub_download
from llama_cpp import Llama
# Handle to the loaded llama.cpp model. load_model() assigns it on success and
# resets it to None on failure; it stays None until a load has succeeded.
llm = None
# Logger named after this module.
logger = logging.getLogger(__name__)
def _env_int(name, default):
    """Return env var *name* as a positive int, or *default* if unset or invalid."""
    raw = os.environ.get(name)
    if raw is None or not raw.strip():
        return default
    try:
        value = int(raw)
    except ValueError:
        logger.warning("Ignoring non-integer %s=%r; using default %d", name, raw, default)
        return default
    if value <= 0:
        logger.warning("Ignoring non-positive %s=%r; using default %d", name, raw, default)
        return default
    return value


def _reset_cache_if_model_changed(cache_dir, model_key):
    """Wipe *cache_dir* if it was populated for a different model, then record *model_key*.

    The marker file ``.current_model`` inside *cache_dir* stores the
    ``"<repo_id>:<filename>"`` key of the last requested model. If it differs
    from *model_key*, the whole directory is deleted to free disk space.

    Best-effort: filesystem errors are logged and swallowed. The cleanup only
    saves space and must not stop the model from being loaded.
    """
    import shutil

    marker = os.path.join(cache_dir, ".current_model")
    try:
        os.makedirs(cache_dir, exist_ok=True)
        if os.path.exists(marker):
            with open(marker, "r", encoding="utf-8") as f:
                previous_key = f.read().strip()
            if previous_key != model_key:
                logger.info(
                    "Model changed from %s to %s. Clearing old cache to save space.",
                    previous_key,
                    model_key,
                )
                shutil.rmtree(cache_dir)
                os.makedirs(cache_dir, exist_ok=True)
        # Record the model being loaded now (written before the download starts).
        with open(marker, "w", encoding="utf-8") as f:
            f.write(model_key)
    except OSError as e:
        logger.warning("Model cache maintenance in %s failed (%s); continuing.", cache_dir, e)


def load_model():
    """Download (if needed) and load the GGUF model into the module-level ``llm``.

    Configuration comes from environment variables:

    * ``MODEL_ID``: Hugging Face repo id
      (default ``bartowski/google_gemma-4-E4B-it-GGUF``).
    * ``MODEL_FILENAME``: GGUF file inside that repo
      (default ``google_gemma-4-E4B-it-Q4_K_M.gguf``).
    * ``MODEL_N_CTX``: context window in tokens (default 8192).
    * ``MODEL_N_THREADS``: CPU threads used for inference (default 2).

    Download or load failures are logged with a traceback and leave
    ``llm`` set to ``None``. Cache-cleanup failures are only logged.
    """
    global llm
    # We will use Gemma 4 E4B (Effective 4B) in 4-bit GGUF by default if not set.
    # The user can still pick a different GGUF model via env variables.
    repo_id = os.environ.get("MODEL_ID", "bartowski/google_gemma-4-E4B-it-GGUF")
    filename = os.environ.get("MODEL_FILENAME", "google_gemma-4-E4B-it-Q4_K_M.gguf")
    # n_ctx is the total context window (input tokens + output tokens).
    # 8192 is safe for a 4B GGUF model and prevents overflow errors on long
    # conversations or web-augmented queries that would fail at 4096.
    n_ctx = _env_int("MODEL_N_CTX", 8192)
    # Default n_threads to 2: HF Spaces free tier only gives 2 vCPUs, while
    # os.cpu_count() reports the host machine's cores (often 64+), which causes
    # extreme thread thrashing and destroys performance.
    n_threads = _env_int("MODEL_N_THREADS", 2)
    logger.info("Loading model - MODEL_ID: %s, MODEL_FILENAME: %s", repo_id, filename)

    cache_dir = "./model_cache"
    _reset_cache_if_model_changed(cache_dir, f"{repo_id}:{filename}")

    logger.info("Checking for model %s (%s) in %s...", repo_id, filename, cache_dir)
    try:
        # Download the model from the Hugging Face Hub (reuses cache_dir if present).
        model_path = hf_hub_download(
            repo_id=repo_id,
            filename=filename,
            cache_dir=cache_dir,
        )
        logger.info("Loading model into memory from %s...", model_path)
        llm = Llama(
            model_path=model_path,
            n_ctx=n_ctx,
            n_threads=n_threads,
            flash_attn=True,
            verbose=False,
        )
        logger.info("Successfully loaded %s", filename)
    except Exception as e:
        # Top-level boundary: keep the app running with llm=None, but keep the traceback.
        logger.exception("Error loading GGUF model: %s", e)
        llm = None
def generate_response_stream(history: list, query: str, max_new_tokens: int = 500):
    """Stream a chat reply to *query* from the loaded model, fragment by fragment.

    Args:
        history: Prior turns, e.g.
            ``[{"role": "user", "content": "msg"}, {"role": "assistant", "content": "msg"}]``.
            The caller's list is not modified.
        query: The new user message.
        max_new_tokens: Maximum number of tokens to generate.

    Yields:
        Text fragments of the reply. If the model is not loaded or generation
        fails, one human-readable message is yielded instead of raising.

    If the prompt does not fit the model's context window, the oldest two
    history messages are dropped and generation is retried. This repeats until
    only the system message and the current query remain. Retries happen
    only before any text has been yielded, so output is never duplicated.
    """
    if llm is None:
        logger.warning("Generate response called but model is not loaded. Returning placeholder.")
        yield "I am a placeholder AI assistant. Please ensure the model downloaded correctly."
        return

    # Work on a copy so the caller's history list is left untouched.
    messages = list(history)
    # Ensure a system prompt leads the conversation unless the caller supplied one.
    if not messages or messages[0].get("role") != "system":
        messages.insert(0, {"role": "system", "content": "You are a helpful AI assistant."})
    messages.append({"role": "user", "content": query})
    # messages[0] is now always the system prompt, so trimming starts at index 1
    # and never touches the system prompt or the final (current) user message.
    first_trimmable = 1

    while True:
        emitted = False
        try:
            response = llm.create_chat_completion(
                messages=messages,
                max_tokens=max_new_tokens,
                temperature=0.7,
                stream=True,
            )
            for chunk in response:
                delta = chunk["choices"][0].get("delta", {})
                if "content" in delta:
                    emitted = True
                    yield delta["content"]
            return
        except Exception as e:
            err_str = str(e).lower()
            # Detected by message text: the error is matched on its wording, not its type.
            is_overflow = "exceed" in err_str and "context" in err_str
            if is_overflow and not emitted:
                # Dropping a turn needs at least two history messages plus the current query.
                if len(messages) - first_trimmable > 2:
                    messages = messages[:first_trimmable] + messages[first_trimmable + 2:]
                    logger.warning(
                        "Context window overflow (%s). "
                        "Dropping oldest history turn and retrying. "
                        "Messages remaining: %d",
                        e,
                        len(messages),
                    )
                    continue
                # Nothing left to trim: surface a clean error.
                logger.error("Context window overflow even with minimal history: %s", e)
                yield "I'm sorry, your query is too long for me to process. Please try a shorter message or start a new conversation."
                return
            # Any other failure, or one after partial output: report it once and stop.
            logger.error("Error generating response: %s", e)
            yield f"Error generating response: {e}"
            return