# Hugging Face Spaces status: Sleeping
import logging
import os
import shutil
from typing import Iterator

from huggingface_hub import hf_hub_download
from llama_cpp import Llama
# Handle to the loaded GGUF model; load_model() assigns it and resets it to None when loading fails.
llm = None
# Module-level logger used by the loader and the response generator.
logger = logging.getLogger(__name__)
def _env_int(name: str, default: int) -> int:
    """Read a positive integer from environment variable *name*, else return *default*.

    Unset or empty values silently fall back to *default*; non-integer or
    non-positive values fall back with a warning, so a typo in the Space
    settings cannot prevent startup.
    """
    raw = os.environ.get(name, "").strip()
    if not raw:
        return default
    try:
        value = int(raw)
    except ValueError:
        value = 0
    if value <= 0:
        logger.warning(f"Ignoring invalid {name}={raw!r}; using default {default}")
        return default
    return value


def _reset_cache_if_model_changed(cache_dir: str, model_key: str) -> None:
    """Empty *cache_dir* if it was last used for a different model, then record *model_key*.

    The ``.current_model`` marker file stores the ``repo_id:filename`` the cache
    was last populated for. On a mismatch the whole directory is deleted to
    save space, so GGUF files of previously used models do not pile up.
    """
    os.makedirs(cache_dir, exist_ok=True)
    marker = os.path.join(cache_dir, ".current_model")
    if os.path.exists(marker):
        with open(marker, "r", encoding="utf-8") as f:
            last_key = f.read().strip()
        if last_key != model_key:
            logger.info(f"Model changed from {last_key} to {model_key}. Clearing old cache to save space.")
            shutil.rmtree(cache_dir)
            os.makedirs(cache_dir, exist_ok=True)
    # Recorded before the download starts, so a later switch to yet another
    # model still clears whatever this (possibly partial) download leaves behind.
    with open(marker, "w", encoding="utf-8") as f:
        f.write(model_key)


def load_model():
    """Download (if needed) and load the GGUF model into the module-level ``llm``.

    Configuration is read from environment variables:
      MODEL_ID         Hugging Face repo (default: bartowski/google_gemma-4-E4B-it-GGUF)
      MODEL_FILENAME   GGUF file in that repo (default: google_gemma-4-E4B-it-Q4_K_M.gguf)
      MODEL_N_CTX      context window in tokens, prompt + output (default: 8192)
      MODEL_N_THREADS  CPU threads used for inference (default: 2)

    Download and load errors are logged with their traceback instead of being
    raised; ``llm`` is then left as None so callers can fall back gracefully.
    """
    global llm

    # Gemma 4 E4B (effective 4B) in 4-bit GGUF by default; any other GGUF
    # model can be selected through the environment variables above.
    repo_id = os.environ.get("MODEL_ID", "bartowski/google_gemma-4-E4B-it-GGUF")
    filename = os.environ.get("MODEL_FILENAME", "google_gemma-4-E4B-it-Q4_K_M.gguf")
    logger.info(f"Loading model - MODEL_ID: {repo_id}, MODEL_FILENAME: {filename}")

    # 8192 tokens is safe for a 4B GGUF model and avoids overflow errors on
    # long or web-augmented conversations that would fail at 4096.
    n_ctx = _env_int("MODEL_N_CTX", 8192)
    # Default to 2 threads: the HF Spaces free tier only provides 2 vCPUs, and
    # os.cpu_count() reports the host's cores (often 64+), which causes heavy
    # thread thrashing. Override on larger hardware via MODEL_N_THREADS.
    n_threads = _env_int("MODEL_N_THREADS", 2)

    cache_dir = "./model_cache"
    try:
        _reset_cache_if_model_changed(cache_dir, f"{repo_id}:{filename}")
    except (OSError, ValueError) as e:
        # Cache housekeeping is best-effort; it must not prevent loading the model.
        logger.warning(f"Model cache cleanup failed ({e}); continuing with the existing cache.")

    logger.info(f"Checking for model {repo_id} ({filename}) in {cache_dir}...")
    try:
        # hf_hub_download reuses files already present in cache_dir.
        model_path = hf_hub_download(
            repo_id=repo_id,
            filename=filename,
            cache_dir=cache_dir,
        )
        logger.info(f"Loading model into memory from {model_path}...")
        llm = Llama(
            model_path=model_path,
            n_ctx=n_ctx,
            n_threads=n_threads,
            flash_attn=True,
            verbose=False,
        )
        logger.info(f"Successfully loaded {filename}")
    except Exception as e:
        # Startup boundary: keep the app alive and serve the placeholder reply.
        logger.exception(f"Error loading GGUF model: {e}")
        llm = None
def _drop_oldest_turn(messages: list) -> list:
    """Return a copy of *messages* with its oldest history turn removed.

    Expects ``[system, *history, latest_user]`` with at least one history
    message. The oldest history message is removed together with any following
    non-user messages, so the remaining history again starts with a user turn
    (many chat templates expect user/assistant alternation). For alternating
    user/assistant history this drops exactly one user+assistant pair.
    """
    last = len(messages) - 1  # index of the current query, which is always kept
    cut = 2
    while cut < last and messages[cut].get("role") != "user":
        cut += 1
    return messages[:1] + messages[cut:]


def generate_response_stream(history: list, query: str, max_new_tokens: int = 500) -> Iterator[str]:
    """Stream an assistant reply to *query* given the prior chat *history*.

    Uses llama.cpp's native chat completion with ``stream=True`` and yields
    text fragments as they arrive. If the prompt is reported to exceed the
    context window, the oldest history turn is dropped and the request is
    retried until only the system prompt and the current query remain.

    Args:
        history: Prior messages, e.g.
            ``[{"role": "user", "content": "msg"}, {"role": "assistant", "content": "msg"}]``.
            The caller's list is never modified; ``None`` is treated as empty.
        query: The new user message.
        max_new_tokens: Upper bound on the number of generated tokens.

    Yields:
        Reply text fragments, or a single user-facing message when the model
        is not loaded, the query alone is too long, or generation fails.
    """
    if not llm:
        logger.warning("Generate response called but model is not loaded. Returning placeholder.")
        yield "I am a placeholder AI assistant. Please ensure the model downloaded correctly."
        return

    messages = list(history or [])
    # Add a default system prompt unless the history already starts with one,
    # so messages[0] is always the system message from here on.
    if not messages or messages[0].get("role") != "system":
        messages.insert(0, {"role": "system", "content": "You are a helpful AI assistant."})
    messages.append({"role": "user", "content": query})

    while True:
        streamed_any = False
        try:
            response = llm.create_chat_completion(
                messages=messages,
                max_tokens=max_new_tokens,
                temperature=0.7,
                stream=True,
            )
            for chunk in response:
                content = chunk["choices"][0].get("delta", {}).get("content")
                if content:
                    streamed_any = True
                    yield content
            return
        except Exception as e:
            err_str = str(e).lower()
            is_overflow = "exceed" in err_str and "context" in err_str
            # Retrying after text was already streamed would repeat it, so
            # only a pre-output overflow is handled by trimming history.
            if not is_overflow or streamed_any:
                logger.exception(f"Error generating response: {e}")
                yield f"Error generating response: {e}"
                return
            if len(messages) <= 2:
                # Only the system prompt and the current query are left.
                logger.error(f"Context window overflow even with minimal history: {e}")
                yield "I'm sorry, your query is too long for me to process. Please try a shorter message or start a new conversation."
                return
            messages = _drop_oldest_turn(messages)
            logger.warning(
                f"Context window overflow ({e}). "
                f"Dropped oldest history turn and retrying. "
                f"History messages remaining: {len(messages) - 2}"
            )