code-gen-assistant / src /rag /generator.py
Rushabh147's picture
fix dtype kwarg + cut max_new_tokens to 64 for faster CPU inference
e615d55
Raw
History Blame Contribute Delete
4.15 kB
"""Phase 5: the CodeAssistant service - the heart of the deployable app.
Wraps a code LLM + an optional retrieval index and exposes:
- generate(intent, mode="baseline"|"rag")
- the prompt builders, so eval/agent code can reuse them.
Designed to be imported by the FastAPI / Gradio / Streamlit front-ends so all
surfaces share one implementation.
"""
from __future__ import annotations
import re
import sys
from pathlib import Path
sys.path.append(str(Path(__file__).resolve().parents[2]))
from src.config import load_config # noqa: E402
from src.rag.embedder import CodeIndex # noqa: E402
# System-role instruction placed first in every chat message list this module
# builds (both the baseline and the retrieval-augmented prompts).
SYSTEM_PROMPT = (
    "You are an expert Python coding assistant. "
    "Write a single, correct, self-contained Python function for the request. "
    "Output only code."
)
# First fenced block in a model reply: an opening ``` with an optional language
# tag (python, py, python3, ...) and optional trailing spaces / CRLF, then the
# body up to the closing ``` -- or up to the end of the text when the fence was
# never closed (e.g. generation stopped at max_new_tokens before the model
# emitted the closing fence).
_FENCE_RE = re.compile(r"```[\w+-]*[ \t]*\r?\n(.*?)(?:```|\Z)", re.DOTALL)


def extract_code(text: str) -> str:
    """Strip markdown fences if the model wrapped its answer.

    Returns the stripped contents of the first fenced block in ``text``. An
    unterminated fence yields everything after its opening line. When ``text``
    contains no fence at all (the model answered with bare code), the whole
    text is returned with surrounding whitespace removed.
    """
    m = _FENCE_RE.search(text)
    return m.group(1).strip() if m else text.strip()
class CodeAssistant:
    """Code-generation service: a causal code LLM plus an optional retrieval index.

    ``generate`` answers a natural-language intent either directly
    (``mode="baseline"``) or with retrieved reference examples prepended to
    the prompt (``mode="rag"``). The prompt builders are public so evaluation
    and agent code can build exactly the same prompts.
    """

    def __init__(self, gen_model: str, index: CodeIndex | None = None,
                 top_k: int = 3, device_map: str = "auto"):
        """Load the tokenizer and causal LM for ``gen_model``.

        Args:
            gen_model: Model id or path passed to ``from_pretrained``.
            index: Optional retrieval index; when ``None``, RAG prompts fall
                back to the baseline prompt.
            top_k: Default number of reference examples retrieved per request.
            device_map: Forwarded to ``AutoModelForCausalLM.from_pretrained``.
        """
        # Local import: transformers is only needed once a model is loaded.
        from transformers import AutoModelForCausalLM, AutoTokenizer
        self.gen_model = gen_model
        self.index = index
        self.top_k = top_k
        self.tok = AutoTokenizer.from_pretrained(gen_model)
        self.model = AutoModelForCausalLM.from_pretrained(
            gen_model, torch_dtype="auto", device_map=device_map
        )

    # ---- prompt builders ------------------------------------------------
    def baseline_messages(self, intent: str) -> list[dict[str, str]]:
        """Return chat messages requesting code for ``intent`` with no extra context."""
        return [{"role": "system", "content": SYSTEM_PROMPT},
                {"role": "user", "content": f"# Task: {intent}"}]

    def rag_messages(self, intent: str,
                     k: int | None = None) -> list[dict[str, str]]:
        """Return chat messages for ``intent`` with retrieved examples as context.

        Args:
            intent: Natural-language task description.
            k: Number of examples to retrieve; falsy values (``None``/``0``)
                fall back to ``self.top_k``.

        Falls back to :meth:`baseline_messages` when there is no index or the
        retrieval returns no rows.
        """
        messages, _ = self._build_rag_prompt(intent, k)
        return messages

    def _build_rag_prompt(self, intent: str, k: int | None = None):
        """Build RAG messages and also return the retrieved rows (``None`` without an index)."""
        if self.index is None:
            return self.baseline_messages(intent), None
        # NOTE(review): assumed to be a pandas DataFrame with docstring/code/score
        # columns -- the row attribute access here and the column selection in
        # ``generate`` rely on that.
        ex = self.index.retrieve(intent, k or self.top_k)
        if len(ex) == 0:
            # Nothing retrieved: an empty "reference examples" section adds
            # only noise, so use the plain prompt.
            return self.baseline_messages(intent), ex
        blocks = [f"# Task: {r.docstring}\n{r.code}" for _, r in ex.iterrows()]
        context = "\n\n".join(blocks)
        user = (f"Here are similar reference examples:\n\n{context}\n\n"
                f"# Now write a function for this task:\n# Task: {intent}")
        messages = [{"role": "system", "content": SYSTEM_PROMPT},
                    {"role": "user", "content": user}]
        return messages, ex

    # ---- generation -----------------------------------------------------
    def _generate(self, messages, max_new_tokens=320, temperature=0.0):
        """Run the chat model on ``messages`` and return only the newly generated text.

        Decoding is greedy unless ``temperature`` is a positive number, in
        which case sampling at that temperature is used.
        """
        text = self.tok.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True)
        # The rendered chat template already contains the special tokens the
        # model expects (e.g. BOS); adding them again would duplicate them.
        inputs = self.tok(text, return_tensors="pt",
                          add_special_tokens=False).to(self.model.device)
        # A real bool: ``temperature and temperature > 0`` would pass 0.0 or
        # None straight into generate(do_sample=...).
        do_sample = temperature is not None and temperature > 0
        kwargs = dict(max_new_tokens=max_new_tokens, do_sample=do_sample,
                      pad_token_id=self.tok.eos_token_id)
        if do_sample:
            kwargs["temperature"] = temperature
        out = self.model.generate(**inputs, **kwargs)
        # The output sequence starts with the prompt tokens; keep only the
        # continuation.
        new = out[0][inputs.input_ids.shape[1]:]
        return self.tok.decode(new, skip_special_tokens=True)

    def generate(self, intent: str, mode: str = "rag", max_new_tokens=320,
                 temperature=0.0, return_sources=False):
        """Generate code for ``intent``.

        Args:
            intent: Natural-language task description.
            mode: ``"rag"`` to include retrieved examples; any other value uses
                the baseline prompt.
            max_new_tokens: Cap on generated tokens.
            temperature: ``0`` (default) for greedy decoding, > 0 to sample.
            return_sources: When true and a RAG prompt was built from an index,
                also return the retrieved rows' ``docstring``/``score`` records.

        Returns:
            The extracted code string, or ``(code, sources)`` when sources were
            requested and are available.
        """
        sources = None
        if mode == "rag":
            # Retrieve once; the same rows feed the prompt and the sources.
            msgs, sources = self._build_rag_prompt(intent)
        else:
            msgs = self.baseline_messages(intent)
        code = extract_code(self._generate(msgs, max_new_tokens, temperature))
        if return_sources and sources is not None:
            return code, sources[["docstring", "score"]].to_dict("records")
        return code

    @classmethod
    def from_config(cls, cfg=None, with_index: bool = True) -> "CodeAssistant":
        """Build an assistant from the project config (``load_config()`` when ``cfg`` is None).

        The retrieval index is loaded from ``cfg.paths.index_dir`` only when
        ``with_index`` is true and a ``code.index`` file exists there;
        otherwise a message is printed and the assistant runs baseline-only.
        """
        cfg = cfg or load_config()
        index = None
        if with_index:
            idx_dir = Path(cfg.paths.index_dir)
            if (idx_dir / "code.index").exists():
                index = CodeIndex.load(str(idx_dir))
            else:
                print("[assistant] no saved index found; running baseline-only.")
        return cls(cfg.models.gen_model, index=index, top_k=cfg.models.top_k)