"""
app/generation.py

LLM generation wrapper supporting two providers:

- hf_router: HuggingFace router (cloud, free tier)
- local:     Ollama (local, GPU recommended)

The provider is set in config.yaml -> generation.provider, so switching
between them is a single config change - no code changes needed.

Usage:
    from app.generation import Generator

    g = Generator()
    result = g.generate("how does numpy clip work?", chunks=[])      # no RAG
    result = g.generate("how does numpy clip work?", chunks=chunks)  # with RAG
    print(result["answer"])
"""

import json
import logging
import os
import time

import requests
import yaml

log = logging.getLogger(__name__)

# Answer text returned to callers when the provider call fails or yields nothing.
NO_RESPONSE_MESSAGE = "Error: no response from model."


def load_config(path: str = "config.yaml") -> dict:
    """Load and parse the YAML config file at ``path``.

    Returns an empty dict for an empty file (``yaml.safe_load`` yields None
    there), so callers get a clear KeyError on missing sections instead of a
    TypeError on ``None``.
    """
    with open(path, encoding="utf-8") as f:
        return yaml.safe_load(f) or {}


# User-message templates live here. The *system* prompt does not: it is loaded
# from config.yaml -> generation.system_prompt, so edit it there.
RAG_TEMPLATE = """The following code was retrieved from the relevant repositories:

{context}

---

Question: {query}

Answer based on the code above:"""

NO_RAG_TEMPLATE = """Question: {query}

Answer:"""


class Generator:
    """
    LLM generation wrapper.

    Supports the HF router (cloud) and Ollama (local); the provider is chosen
    via config.yaml -> generation.provider ("local" selects Ollama, any other
    value selects the HF router). Maintains an internal conversation history
    for multi-turn conversations.
    """

    HF_ROUTER_URL = "https://router.huggingface.co/v1/chat/completions"
    OLLAMA_CHAT_URL = "http://localhost:11434/api/chat"

    def __init__(self, config_path: str = "config.yaml"):
        """Read generation settings from ``config_path`` and the HF token from the env.

        Raises:
            KeyError: if the config lacks ``generation`` or ``generation.model``.
        """
        self.cfg = load_config(config_path)
        gen_cfg = self.cfg["generation"]

        self.provider = gen_cfg.get("provider", "hf_router")
        self.model = gen_cfg["model"]
        self.max_new_tokens = gen_cfg.get("max_new_tokens", 512)
        self.temperature = gen_cfg.get("temperature", 0.1)
        self.hf_token = os.environ.get("HF_API_TOKEN", "")

        if self.provider != "local" and not self.hf_token:
            # Requests would be sent with an empty bearer token.
            log.warning("HF_API_TOKEN is not set; HF router requests will likely be rejected.")

        # Load system prompt from config (single source of truth).
        self.system_prompt = gen_cfg.get("system_prompt", "You are a helpful coding assistant.")

        # Internal conversation history: alternating user/assistant messages
        # from successful turns only.
        self.conversation_history: list[dict] = []

    # HF Router

    def _call_hf_router(self, messages: list[dict], retries: int = 3,
                        retry_delay: float = 10.0) -> str | None:
        """Send ``messages`` to the HF router chat-completions endpoint.

        Retries on rate limiting (429/529, exponential backoff), model loading
        (503, linear backoff) and timeouts (fixed delay). No sleep happens after
        the final attempt.

        Returns:
            The stripped completion text, or None on any non-retryable error,
            a null completion, or when all attempts fail.
        """
        headers = {
            "Authorization": f"Bearer {self.hf_token}",
            "Content-Type": "application/json",
        }
        payload = {
            "model": self.model,
            "messages": messages,
            "max_tokens": self.max_new_tokens,
            "temperature": self.temperature,
        }

        for attempt in range(retries):
            try:
                resp = requests.post(
                    self.HF_ROUTER_URL, headers=headers,
                    data=json.dumps(payload), timeout=90,
                )
                if resp.status_code == 200:
                    content = resp.json()["choices"][0]["message"]["content"]
                    if not isinstance(content, str):
                        # OpenAI-compatible APIs may return null content.
                        log.error("HF API returned no message content")
                        return None
                    return content.strip()
                if resp.status_code in (429, 529):
                    reason, wait = "Rate limited", retry_delay * (2 ** attempt)
                elif resp.status_code == 503:
                    reason, wait = "Model loading", retry_delay * (attempt + 1)
                else:
                    log.error("HF API error %d: %s", resp.status_code, resp.text[:300])
                    return None
            except requests.exceptions.Timeout:
                reason, wait = "Timeout", retry_delay
            except Exception as e:
                # Boundary: network/parse failures become a None answer.
                log.error("Request error: %s", e)
                return None

            if attempt < retries - 1:
                log.warning("%s (attempt %d/%d) - waiting %.0fs...",
                            reason, attempt + 1, retries, wait)
                time.sleep(wait)
            else:
                log.warning("%s (attempt %d/%d)", reason, attempt + 1, retries)

        log.error("All %d attempts failed", retries)
        return None

    # Ollama (local)

    def _call_ollama(self, messages: list[dict], retries: int = 3) -> str | None:
        """
        Call the local Ollama server (http://localhost:11434).

        Install:    curl -fsSL https://ollama.ai/install.sh | sh
        Pull model: ollama pull mistral   (or llama3.1, phi3, etc.)

        A connection error or non-200 status aborts immediately. Other
        exceptions (e.g. read timeouts, malformed responses) are retried
        after a 2 s pause, with no pause after the final attempt.

        Returns:
            The stripped reply text, or None on failure.
        """
        payload = {
            "model": self.model,
            "messages": messages,
            "stream": False,
            "options": {
                "temperature": self.temperature,
                "num_predict": self.max_new_tokens,
            },
        }

        for attempt in range(retries):
            try:
                resp = requests.post(self.OLLAMA_CHAT_URL, json=payload, timeout=120)
                if resp.status_code == 200:
                    return resp.json()["message"]["content"].strip()
                log.error("Ollama error %d: %s", resp.status_code, resp.text[:200])
                return None
            except requests.exceptions.ConnectionError:
                log.error(
                    "Cannot connect to Ollama at localhost:11434. "
                    "Is it running? Start with: ollama serve"
                )
                return None
            except Exception as e:
                log.error("Ollama request error (attempt %d): %s", attempt + 1, e)
                if attempt < retries - 1:
                    time.sleep(2)

        log.error("All %d Ollama attempts failed", retries)
        return None

    # Public interface

    def _build_context(self, chunks: list[dict] | None,
                       context_str: str | None) -> str | None:
        """Return the RAG context text: ``context_str`` if non-empty, else formatted ``chunks``, else None."""
        if context_str:
            return context_str
        if not chunks:
            return None
        from app.retrieval import Retriever
        # Build a Retriever without running __init__ (presumably to avoid its
        # setup cost); only ``cfg`` is set, so this assumes format_context needs
        # nothing else - verify against app.retrieval.
        r = Retriever.__new__(Retriever)
        r.cfg = self.cfg
        return r.format_context(chunks)

    def generate(self, query: str, chunks: list[dict] | None = None,
                 context_str: str | None = None, use_history: bool = True) -> dict:
        """
        Generate an answer for a query.

        Args:
            query: The user's question
            chunks: Retrieved chunks (optional - if None or empty, no RAG)
            context_str: Pre-formatted context string (overrides chunks)
            use_history: If True, include the internal conversation history in
                the prompt and record this turn in it when the model answers
                (failed turns are not recorded, so the error placeholder is
                never replayed to the model). Default: True.

        Returns:
            dict with keys:
                answer: Generated text (NO_RESPONSE_MESSAGE on failure)
                has_rag: Whether RAG context was used
                model: Model name used
                provider: Provider used
                duration_s: Time taken, rounded to 0.01 s
        """
        t0 = time.perf_counter()

        context = self._build_context(chunks, context_str)
        has_rag = bool(context)

        if has_rag:
            user_content = RAG_TEMPLATE.format(context=context, query=query)
        else:
            user_content = NO_RAG_TEMPLATE.format(query=query)

        messages = [{"role": "system", "content": self.system_prompt}]
        if use_history:
            messages.extend(self.conversation_history)
        messages.append({"role": "user", "content": user_content})

        if self.provider == "local":
            answer = self._call_ollama(messages)
        else:
            answer = self._call_hf_router(messages)

        duration = time.perf_counter() - t0

        # NOTE: the stored user turn includes the full RAG context, so history
        # grows by the retrieved context on every RAG turn.
        if use_history and answer:
            self.conversation_history.append({"role": "user", "content": user_content})
            self.conversation_history.append({"role": "assistant", "content": answer})

        return {
            "answer": answer or NO_RESPONSE_MESSAGE,
            "has_rag": has_rag,
            "model": self.model,
            "provider": self.provider,
            "duration_s": round(duration, 2),
        }

    def clear_history(self) -> None:
        """Clear the internal conversation history."""
        self.conversation_history = []

    def get_history(self) -> list[dict]:
        """Return a copy of the conversation history (each message dict copied too)."""
        return [dict(message) for message in self.conversation_history]