File size: 4,149 Bytes
b89e6d6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e615d55
b89e6d6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
"""Phase 5: the CodeAssistant service - the heart of the deployable app.

Wraps a code LLM + an optional retrieval index and exposes:
  - generate(intent, mode="baseline"|"rag")
  - the prompt builders, so eval/agent code can reuse them.

Designed to be imported by the FastAPI / Gradio / Streamlit front-ends so all
surfaces share one implementation.
"""
from __future__ import annotations

import re
import sys
from pathlib import Path

sys.path.append(str(Path(__file__).resolve().parents[2]))
from src.config import load_config  # noqa: E402
from src.rag.embedder import CodeIndex  # noqa: E402

# System message shared by every prompt builder: one self-contained Python
# function, code only. Sentences are joined with single spaces.
SYSTEM_PROMPT = " ".join((
    "You are an expert Python coding assistant.",
    "Write a single, correct, self-contained Python function for the request.",
    "Output only code.",
))

# First markdown code fence in a model reply. Accepts an optional
# python/python3/py language tag (any case), trailing blanks and CRLF on the
# opening line, and an unterminated fence (reply cut off by max_new_tokens),
# in which case everything up to the end of the text is captured.
_FENCE_RE = re.compile(
    r"```[ \t]*(?:python3?|py)?[ \t]*\r?\n(.*?)(?:```|\Z)",
    re.DOTALL | re.IGNORECASE,
)


def extract_code(text: str) -> str:
    """Return the code inside the first markdown fence of *text*.

    If *text* contains no recognizable fence it is returned unchanged apart
    from surrounding whitespace being stripped.
    """
    m = _FENCE_RE.search(text)
    return m.group(1).strip() if m else text.strip()


class CodeAssistant:
    """Code LLM plus an optional retrieval index.

    ``generate`` is the main entry point. ``baseline_messages`` and
    ``rag_messages`` are public so eval/agent code can build the same prompts.
    """

    def __init__(self, gen_model: str, index: CodeIndex | None = None,
                 top_k: int = 3, device_map: str = "auto"):
        """Load the tokenizer and causal LM for *gen_model*.

        Args:
            gen_model: model name/path passed to ``from_pretrained``.
            index: optional retrieval index; without one, RAG prompts fall
                back to the baseline prompt.
            top_k: default number of examples retrieved for RAG prompts.
            device_map: forwarded to ``AutoModelForCausalLM.from_pretrained``.
        """
        # Deferred import: transformers is only loaded when an assistant is
        # actually constructed, not when this module is imported.
        from transformers import AutoModelForCausalLM, AutoTokenizer

        self.gen_model = gen_model
        self.index = index
        self.top_k = top_k
        self.tok = AutoTokenizer.from_pretrained(gen_model)
        self.model = AutoModelForCausalLM.from_pretrained(
            gen_model, torch_dtype="auto", device_map=device_map
        )

    # ---- prompt builders ------------------------------------------------
    def baseline_messages(self, intent: str):
        """Return chat messages asking for a function, with no retrieved context."""
        return [{"role": "system", "content": SYSTEM_PROMPT},
                {"role": "user", "content": f"# Task: {intent}"}]

    def rag_messages(self, intent: str, k: int | None = None):
        """Return chat messages that prepend retrieved examples to the task.

        *k* is the number of examples to retrieve and defaults to
        ``self.top_k``. Falls back to ``baseline_messages`` when there is no
        index, when *k* <= 0, or when retrieval returns no rows.
        """
        messages, _ = self._rag_messages_with_examples(intent, k)
        return messages

    def _rag_messages_with_examples(self, intent: str, k: int | None = None):
        """Build RAG messages and also return the retrieved examples (or None).

        Shared by ``rag_messages`` and ``generate`` so ``generate`` can report
        exactly the examples placed in the prompt without a second retrieval.
        """
        if self.index is None:
            return self.baseline_messages(intent), None
        # `k is None` rather than `k or ...`: an explicit k=0 means "no
        # examples" instead of silently becoming top_k.
        k = self.top_k if k is None else k
        if k <= 0:
            return self.baseline_messages(intent), None
        # NOTE(review): presumably a pandas DataFrame; the columns read here and
        # in `generate` are `docstring`, `code` and `score`.
        examples = self.index.retrieve(intent, k)
        if len(examples) == 0:
            # Avoid a prompt that announces reference examples but has none.
            return self.baseline_messages(intent), examples
        blocks = [f"# Task: {row.docstring}\n{row.code}"
                  for _, row in examples.iterrows()]
        context = "\n\n".join(blocks)
        user = (f"Here are similar reference examples:\n\n{context}\n\n"
                f"# Now write a function for this task:\n# Task: {intent}")
        messages = [{"role": "system", "content": SYSTEM_PROMPT},
                    {"role": "user", "content": user}]
        return messages, examples

    # ---- generation -----------------------------------------------------
    def _generate(self, messages, max_new_tokens=320, temperature=0.0):
        """Run the chat model on *messages* and return only the new text.

        Decoding is greedy unless *temperature* is a positive number.
        """
        text = self.tok.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True)
        # The rendered template already contains the model's special tokens;
        # letting the tokenizer add them again would e.g. double the BOS token.
        inputs = self.tok(text, return_tensors="pt",
                          add_special_tokens=False).to(self.model.device)
        # Always a real bool: `temperature and ...` yielded 0.0/None here.
        do_sample = temperature is not None and temperature > 0
        kwargs = dict(max_new_tokens=max_new_tokens, do_sample=do_sample,
                      pad_token_id=self.tok.eos_token_id)
        if do_sample:
            kwargs["temperature"] = temperature
        out = self.model.generate(**inputs, **kwargs)
        # generate() returns prompt + continuation; drop the prompt tokens.
        prompt_len = inputs["input_ids"].shape[1]
        return self.tok.decode(out[0][prompt_len:], skip_special_tokens=True)

    def generate(self, intent: str, mode: str = "rag", max_new_tokens=320,
                 temperature=0.0, return_sources=False):
        """Generate code for *intent*, with markdown fences stripped.

        Args:
            intent: natural-language task description.
            mode: ``"rag"`` for a retrieval-augmented prompt; any other value
                uses the baseline prompt.
            max_new_tokens: generation budget.
            temperature: sampling temperature; <= 0 or None means greedy.
            return_sources: when true, *mode* is ``"rag"`` and an index is
                loaded, also return the examples used in the prompt.

        Returns:
            The code string, or ``(code, sources)`` in the *return_sources*
            case above, where *sources* is a list of dicts with ``docstring``
            and ``score`` keys.
        """
        examples = None
        if mode == "rag":
            msgs, examples = self._rag_messages_with_examples(intent)
        else:
            msgs = self.baseline_messages(intent)
        code = extract_code(self._generate(msgs, max_new_tokens, temperature))
        if return_sources and mode == "rag" and self.index is not None:
            # Reuse the prompt's retrieval instead of querying the index again,
            # so the reported sources are exactly what the model saw.
            sources = ([] if examples is None
                       else examples[["docstring", "score"]].to_dict("records"))
            return code, sources
        return code

    @classmethod
    def from_config(cls, cfg=None, with_index: bool = True) -> "CodeAssistant":
        """Build an assistant from the project config.

        When *with_index* is true, loads the saved index from
        ``cfg.paths.index_dir`` if a ``code.index`` file exists there;
        otherwise prints a notice and runs without retrieval.
        """
        cfg = cfg or load_config()
        index = None
        if with_index:
            idx_dir = Path(cfg.paths.index_dir)
            if (idx_dir / "code.index").exists():
                index = CodeIndex.load(str(idx_dir))
            else:
                print("[assistant] no saved index found; running baseline-only.")
        return cls(cfg.models.gen_model, index=index, top_k=cfg.models.top_k)