substrate / app /generation.py
Syed Taha
refactor: streamline LLM call and maintain internal conversation history
39bcfbd
"""
app/generation.py
LLM generation wrapper supporting two providers:
- hf_router: HuggingFace router (cloud, free tier)
- local: Ollama (local, GPU recommended)
Provider is set in config.yaml -> generation.provider
Switch between them with one config change - no code changes needed.
Usage:
from app.generation import Generator
g = Generator()
result = g.generate("how does numpy clip work?", chunks=[]) # no RAG
result = g.generate("how does numpy clip work?", chunks=chunks) # with RAG
answer = result["answer"] # result also has: has_rag, model, provider, duration_s
"""
import json
import logging
import os
import time
import requests
import yaml
# Logger for this module, named after the module itself.
log = logging.getLogger(name=__name__)
def load_config(path: str = "config.yaml") -> dict:
    """
    Load a YAML config file and return its parsed top-level value.

    Args:
        path: Path to the YAML file.

    Returns:
        The parsed document. An empty file yields ``{}`` instead of the
        ``None`` that ``yaml.safe_load`` returns for an empty document, so
        the ``dict`` return annotation holds.
    """
    # Decode explicitly as UTF-8 so non-ASCII text (e.g. a system prompt)
    # reads the same on every platform, independent of the locale default.
    with open(path, encoding="utf-8") as f:
        data = yaml.safe_load(f)
    return {} if data is None else data
# User-message prompt templates, filled with str.format() on every turn.
# Only the *system* prompt is configurable (config.yaml -> generation.system_prompt);
# these templates are defined here. Because they go through str.format(),
# any literal brace added to them must be doubled ("{{" / "}}").
# RAG_TEMPLATE: used when retrieved context is available; placeholders {context}, {query}.
RAG_TEMPLATE = """The following code was retrieved from the relevant repositories:
{context}
---
Question: {query}
Answer based on the code above:"""
# NO_RAG_TEMPLATE: used when no context is available; placeholder {query}.
NO_RAG_TEMPLATE = """Question: {query}
Answer:"""
class Generator:
"""
LLM generation wrapper.
Supports HF router (cloud) and Ollama (local) via config.yaml.
Maintains internal conversation history for multi-turn conversations.
"""
def __init__(self, config_path: str = "config.yaml"):
self.cfg = load_config(config_path)
gen_cfg = self.cfg["generation"]
self.provider = gen_cfg.get("provider", "hf_router")
self.model = gen_cfg["model"]
self.max_new_tokens = gen_cfg.get("max_new_tokens", 512)
self.temperature = gen_cfg.get("temperature", 0.1)
self.hf_token = os.environ.get("HF_API_TOKEN", "")
# Load system prompt from config (single source of truth)
self.system_prompt = gen_cfg.get("system_prompt", "You are a helpful coding assistant.")
# Internal conversation history
self.conversation_history: list[dict] = []
# HF Router
def _call_hf_router(self, messages: list[dict], retries: int = 3,
retry_delay: float = 10.0, ) -> str | None:
url = "https://router.huggingface.co/v1/chat/completions"
headers = {
"Authorization": f"Bearer {self.hf_token}",
"Content-Type": "application/json",
}
payload = {
"model": self.model,
"messages": messages,
"max_tokens": self.max_new_tokens,
"temperature": self.temperature,
}
for attempt in range(retries):
try:
resp = requests.post(
url, headers=headers,
data=json.dumps(payload), timeout=90
)
if resp.status_code == 200:
return resp.json()["choices"][0]["message"]["content"].strip()
elif resp.status_code in (429, 529):
wait = retry_delay * (2 ** attempt)
log.warning("Rate limited - waiting %.0fs...", wait)
time.sleep(wait)
elif resp.status_code == 503:
wait = retry_delay * (attempt + 1)
log.warning("Model loading - waiting %.0fs...", wait)
time.sleep(wait)
else:
log.error("HF API error %d: %s", resp.status_code, resp.text[:300])
return None
except requests.exceptions.Timeout:
log.warning("Timeout (attempt %d/%d)", attempt + 1, retries)
time.sleep(retry_delay)
except Exception as e:
log.error("Request error: %s", e)
return None
log.error("All %d attempts failed", retries)
return None
# Ollama (local)
def _call_ollama(self, messages: list[dict], retries: int = 3) -> str | None:
"""
Call local Ollama server (http://localhost:11434).
Install: curl -fsSL https://ollama.ai/install.sh | sh
Pull model: ollama pull mistral (or llama3.1, phi3, etc.)
"""
url = "http://localhost:11434/api/chat"
payload = {
"model": self.model,
"messages": messages,
"stream": False,
"options": {
"temperature": self.temperature,
"num_predict": self.max_new_tokens,
},
}
for attempt in range(retries):
try:
resp = requests.post(url, json=payload, timeout=120)
if resp.status_code == 200:
return resp.json()["message"]["content"].strip()
else:
log.error("Ollama error %d: %s", resp.status_code, resp.text[:200])
return None
except requests.exceptions.ConnectionError:
log.error(
"Cannot connect to Ollama at localhost:11434. "
"Is it running? Start with: ollama serve"
)
return None
except Exception as e:
log.error("Ollama request error (attempt %d): %s", attempt + 1, e)
time.sleep(2)
return None
# Public interface
def generate(self, query: str, chunks: list[dict] | None = None,
context_str: str | None = None, use_history: bool = True) -> dict:
"""
Generate an answer for a query.
Args:
query: The user's question
chunks: Retrieved chunks (optional - if None, no RAG)
context_str: Pre-formatted context string (overrides chunks)
use_history: If True, include and update internal conversation history (default: True)
Returns:
dict with keys:
answer: Generated text
has_rag: Whether RAG context was used
model: Model name used
provider: Provider used
duration_s: Time taken
"""
t0 = time.time()
# Build context
if context_str:
context = context_str
elif chunks:
from app.retrieval import Retriever
r = Retriever.__new__(Retriever)
r.cfg = self.cfg
context = r.format_context(chunks)
else:
context = None
has_rag = bool(context)
# Build messages
if has_rag:
user_content = RAG_TEMPLATE.format(context=context, query=query)
else:
user_content = NO_RAG_TEMPLATE.format(query=query)
messages = [
{"role": "system", "content": self.system_prompt},
]
# Add internal conversation history if enabled
if use_history:
messages.extend(self.conversation_history)
# Add current query
messages.append({"role": "user", "content": user_content})
# Call provider
if self.provider == "local":
answer = self._call_ollama(messages)
else:
answer = self._call_hf_router(messages)
duration = time.time() - t0
# Update internal history if enabled
if use_history:
self.conversation_history.append({"role": "user", "content": user_content})
self.conversation_history.append({"role": "assistant", "content": answer or "Error: no response from model."})
return {
"answer": answer or "Error: no response from model.",
"has_rag": has_rag,
"model": self.model,
"provider": self.provider,
"duration_s": round(duration, 2),
}
def clear_history(self) -> None:
"""Clear the internal conversation history."""
self.conversation_history = []
def get_history(self) -> list[dict]:
"""Get the current conversation history."""
return self.conversation_history.copy()