import logging
import os
import shutil
from typing import Iterator

from huggingface_hub import hf_hub_download
from llama_cpp import Llama

logger = logging.getLogger(__name__)

# Loaded model instance; stays None until load_model() succeeds.
llm = None

# Directory handed to hf_hub_download as its cache, plus a marker file inside it
# that records which "<repo_id>:<filename>" the cache currently belongs to.
_CACHE_DIR = "./model_cache"
_MODEL_MARKER_NAME = ".current_model"

# n_ctx is the total context window (input tokens + output tokens).
# 8192 is safe for a 4B GGUF model and prevents overflow errors on long
# conversations or web-augmented queries that would fail at 4096.
_N_CTX = 8192

# HARDCODE n_threads to 2. HF Spaces free tier only gives 2 vCPUs.
# os.cpu_count() returns the host machine's cores (often 64+) which causes
# extreme thread thrashing and destroys performance.
_N_THREADS = 2

_DEFAULT_SYSTEM_PROMPT = "You are a helpful AI assistant."


def load_model() -> None:
    """Download (if needed) and load the GGUF model into memory.

    The model is selected with the ``MODEL_ID`` and ``MODEL_FILENAME``
    environment variables, falling back to a 4-bit Gemma GGUF build.  When the
    selection differs from the one recorded in the cache marker file, the whole
    cache directory is wiped first so old weights do not pile up on disk.

    On success the module-level ``llm`` holds the loaded ``Llama`` instance.
    Download or load failures are logged and leave ``llm`` set to ``None``;
    they are not re-raised.
    """
    global llm

    # Default to Gemma 4 E4B (Effective 4B) in 4-bit GGUF; either value can be
    # overridden through the environment.
    repo_id = os.environ.get("MODEL_ID", "bartowski/google_gemma-4-E4B-it-GGUF")
    filename = os.environ.get("MODEL_FILENAME", "google_gemma-4-E4B-it-Q4_K_M.gguf")
    logger.info("Loading model - MODEL_ID: %s, MODEL_FILENAME: %s", repo_id, filename)

    cache_dir = _CACHE_DIR

    # --- Cache cleanup: drop previously cached files when the model changes ---
    os.makedirs(cache_dir, exist_ok=True)
    model_info_file = os.path.join(cache_dir, _MODEL_MARKER_NAME)
    current_model_str = f"{repo_id}:{filename}"

    if os.path.exists(model_info_file):
        with open(model_info_file, "r", encoding="utf-8") as marker:
            last_model_str = marker.read().strip()
        # Compare only after the marker is closed so the directory can be
        # removed without an open handle inside it.
        if last_model_str != current_model_str:
            logger.info(
                "Model changed from %s to %s. Clearing old cache to save space.",
                last_model_str,
                current_model_str,
            )
            shutil.rmtree(cache_dir)
            os.makedirs(cache_dir, exist_ok=True)

    # Record the current selection for the next start-up.
    with open(model_info_file, "w", encoding="utf-8") as marker:
        marker.write(current_model_str)
    # ---------------------------------------------------------------------

    logger.info("Checking for model %s (%s) in %s...", repo_id, filename, cache_dir)

    try:
        # Download the model from HuggingFace Hub (this uses the cache automatically).
        model_path = hf_hub_download(
            repo_id=repo_id,
            filename=filename,
            cache_dir=cache_dir,
        )

        logger.info("Loading model into memory from %s...", model_path)
        llm = Llama(
            model_path=model_path,
            n_ctx=_N_CTX,
            n_threads=_N_THREADS,
            flash_attn=True,
            verbose=False,
        )
        logger.info("Successfully loaded %s", filename)
    except Exception as e:
        # Top-level boundary: keep the process alive without a model and let
        # generate_response_stream() fall back to its placeholder reply.
        logger.exception("Error loading GGUF model: %s", e)
        llm = None


def generate_response_stream(
    history: list, query: str, max_new_tokens: int = 500
) -> Iterator[str]:
    """Stream a chat completion for ``query``, yielding text chunks.

    Args:
        history: Prior conversation, oldest first, in the form
            ``[{"role": "user", "content": "msg"},
            {"role": "assistant", "content": "msg"}]``.  The list itself is
            not modified.
        query: The new user message to answer.
        max_new_tokens: Value passed as ``max_tokens`` to the completion call.

    Yields:
        The ``content`` pieces of each streamed delta.  If no model is loaded,
        a single placeholder sentence is yielded instead.  If generation
        fails, a single human-readable error message is yielded; exceptions
        are not propagated to the caller.
    """
    if llm is None:
        logger.warning("Generate response called but model is not loaded. Returning placeholder.")
        yield "I am a placeholder AI assistant. Please ensure the model downloaded correctly."
        return

    # Work on a copy so the caller's history list is left untouched.
    messages = history.copy()

    # Make sure the conversation starts with a system prompt.
    if not messages or messages[0].get("role") != "system":
        messages.insert(0, {"role": "system", "content": _DEFAULT_SYSTEM_PROMPT})

    messages.append({"role": "user", "content": query})

    # Set once any text has reached the caller; after that a retry would
    # replay the answer from the start and duplicate what was already streamed.
    emitted_any = False

    # Retry loop: if the prompt is too long for the context window, drop the
    # oldest history turn and try again.  Stop retrying once only the system
    # message + current query remain.
    while True:
        try:
            response = llm.create_chat_completion(
                messages=messages,
                max_tokens=max_new_tokens,
                temperature=0.7,
                stream=True,
            )

            for chunk in response:
                delta = chunk["choices"][0].get("delta", {})
                if "content" in delta:
                    emitted_any = True
                    yield delta["content"]
            return

        except Exception as e:
            err_str = str(e).lower()
            is_overflow = "exceed" in err_str and "context" in err_str

            if not is_overflow or emitted_any:
                logger.exception("Error generating response: %s", e)
                yield f"Error generating response: {e}"
                return

            # messages layout: [system?, ...history..., latest_user]
            start = 1 if messages[0].get("role") == "system" else 0
            # History messages that sit between the system prompt and the
            # latest user query; the latest query itself is never dropped.
            droppable = len(messages) - start - 1

            if droppable > 0:
                # Normally drop a full user+assistant pair; fall back to a
                # single message when only one history entry is left.
                drop_count = min(2, droppable)
                logger.warning(
                    "Context window overflow (%s). "
                    "Dropping oldest history turn and retrying. "
                    "Messages remaining: %d",
                    e,
                    len(messages) - drop_count,
                )
                messages = messages[:start] + messages[start + drop_count:]
                continue

            # Nothing left to trim — surface a clean error.
            logger.error("Context window overflow even with minimal history: %s", e)
            yield "I'm sorry, your query is too long for me to process. Please try a shorter message or start a new conversation."
            return