"""
Component 7: Inference engine for local code generation.

Features:
- Deterministic greedy decoding (temperature 0) on the first attempt.
- Stop rules for clean function completion.
- Syntax-aware retry with up to 3 attempts.
"""

from __future__ import annotations

import ast
from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple, Union

import torch

from src.evaluation_system.code_eval import restore_code_from_structured
from src.model_architecture.code_transformer import CodeTransformerLM
from src.tokenizer.code_tokenizer import CodeTokenizer

# Top-level function node types recognised by the completion heuristics.
_FunctionNode = Union[ast.FunctionDef, ast.AsyncFunctionDef]


@dataclass
class DecodingConfig:
    """Decoding knobs for :meth:`InferenceEngine.generate_with_retry`.

    Attempt 1 uses ``greedy_temperature`` (top-p 1.0); attempts 2 and 3 use the
    ``retry2_*`` / ``retry3_*`` pairs. At most ``min(max_retries, 3)`` attempts
    run, and always at least one.
    """

    max_new_tokens: int = 300
    # Mode 1: deterministic output (temperature <= 0 selects the argmax token).
    greedy_temperature: float = 0.0
    # Retry mode 2
    retry2_temperature: float = 0.25
    retry2_top_p: float = 0.85
    # Retry mode 3
    retry3_temperature: float = 0.35
    retry3_top_p: float = 0.90
    # Total number of attempts, including the first greedy one (capped at 3).
    max_retries: int = 3
    # The parse-based stop check is skipped until this many tokens are generated.
    min_tokens_before_stop_check: int = 64
    # Stop only when function body is non-trivial.
    min_function_body_statements: int = 2
    # Run the parse-based stop check once every this many generated tokens.
    stop_check_interval: int = 12


class InferenceEngine:
    """Autoregressive code generation with stop rules and syntax-aware retry."""

    def __init__(self, model: CodeTransformerLM, tokenizer: CodeTokenizer, device: torch.device) -> None:
        self.model = model
        self.tokenizer = tokenizer
        self.device = device
        # Inference only: switch the model to eval mode once, up front.
        self.model.eval()

    @staticmethod
    def _syntax_ok_python(code: str) -> bool:
        """Return True if ``code`` parses as Python source, False otherwise."""
        try:
            ast.parse(code)
        except Exception:
            # Best-effort check: any parse failure (e.g. SyntaxError) means invalid.
            return False
        return True

    @staticmethod
    def _last_function_def(code: str) -> Optional[_FunctionNode]:
        """Return the last top-level ``def`` / ``async def`` node in ``code``.

        Returns None when ``code`` does not parse or defines no top-level function.
        """
        try:
            tree = ast.parse(code)
        except Exception:
            return None
        return next(
            (n for n in reversed(tree.body) if isinstance(n, (ast.FunctionDef, ast.AsyncFunctionDef))),
            None,
        )

    @staticmethod
    def _function_completion_score(code: str) -> int:
        """Heuristic completeness score for the last top-level function in ``code``.

        Higher score = more complete usable function. The score is the number of
        top-level statements in the function body, plus 2 if a ``return`` occurs
        anywhere inside the function (nested scopes included). Unparseable code,
        or code with no top-level function, scores 0.
        """
        fn = InferenceEngine._last_function_def(code)
        if fn is None:
            return 0
        has_return = any(isinstance(n, ast.Return) for n in ast.walk(fn))
        return len(fn.body) + (2 if has_return else 0)

    def _looks_complete_function(self, code: str, min_body_statements: int) -> bool:
        """Return True if ``code`` parses and its last top-level function has at
        least ``min_body_statements`` top-level body statements."""
        # Cheap substring pre-check before paying for a full parse.
        if "def " not in code:
            return False
        fn = self._last_function_def(code)
        return fn is not None and len(fn.body) >= min_body_statements

    def _sample_next(
        self,
        logits: torch.Tensor,
        temperature: float,
        top_p: float,
    ) -> torch.Tensor:
        """Choose the next token id from last-position ``logits``.

        ``temperature <= 0`` returns the argmax (greedy). Otherwise the logits are
        divided by ``temperature`` and nucleus (top-p) sampling is applied: tokens
        are kept in descending probability until the cumulative mass exceeds
        ``top_p`` (the token that crosses the threshold, and always the most
        likely token, are kept), the kept mass is renormalised, and one id is
        sampled. For 2-D ``(batch, vocab)`` logits the result has shape ``(batch, 1)``.
        """
        if temperature <= 0:
            return torch.argmax(logits, dim=-1, keepdim=True)
        logits = logits / temperature
        probs = torch.softmax(logits, dim=-1)
        sorted_probs, sorted_idx = torch.sort(probs, descending=True)
        cumulative = torch.cumsum(sorted_probs, dim=-1)
        # Shift the mask right by one so the threshold-crossing token survives,
        # and never mask the single most likely token.
        cutoff = cumulative > top_p
        cutoff[..., 1:] = cutoff[..., :-1].clone()
        cutoff[..., 0] = False
        sorted_probs[cutoff] = 0.0
        denom = sorted_probs.sum(dim=-1, keepdim=True).clamp_min(1e-12)
        sorted_probs = sorted_probs / denom
        sampled = torch.multinomial(sorted_probs, num_samples=1)
        # Map positions in the sorted order back to vocabulary ids.
        return sorted_idx.gather(-1, sampled)

    def _encode_prompt(self, prompt: str, language: str) -> Tuple[torch.Tensor, Optional[int]]:
        """Build the generation prompt and tokenize it.

        Returns ``(input_ids, eos_id)``: ``input_ids`` is a ``(1, T)`` long tensor
        on ``self.device``; ``eos_id`` is None when the tokenizer's
        ``special_token_ids`` lookup finds nothing.
        """
        prompt_text = self.tokenizer.format_training_sample(prompt=prompt, code="", language=language)
        # NOTE(review): as written this removes *every* space from the formatted
        # prompt. It looks like it was meant to strip one specific marker string
        # (e.g. a special token) — confirm the intended literal.
        prompt_text = prompt_text.replace(" ", "").strip()
        ids = self.tokenizer.encode(prompt_text)
        # NOTE(review): this looks up the empty-string key. If the tokenizer names
        # its EOS token differently, eos_id is None and EOS stopping is disabled.
        raw_eos = self.tokenizer.special_token_ids.get("")
        eos_id = int(raw_eos) if raw_eos is not None else None
        # Remove trailing EOS so generation can continue.
        if eos_id is not None and len(ids) > 1 and ids[-1] == eos_id:
            ids = ids[:-1]
        return torch.tensor([ids], dtype=torch.long, device=self.device), eos_id

    def _decode_code(self, input_ids: torch.Tensor) -> str:
        """Decode row 0 of ``input_ids`` (prompt + generated tokens) and extract code."""
        decoded = self.tokenizer.decode(input_ids[0].tolist())
        return restore_code_from_structured(decoded)

    @torch.no_grad()
    def _generate_once(
        self,
        prompt: str,
        language: str,
        max_new_tokens: int,
        temperature: float,
        top_p: float,
        min_tokens_before_stop_check: int,
        min_function_body_statements: int,
        stop_check_interval: int = 12,
    ) -> Dict[str, object]:
        """Run a single decoding pass and return the generated code plus metadata.

        Generation stops at ``max_new_tokens``, on the EOS token, or when the code
        decoded so far contains a parseable top-level function with at least
        ``min_function_body_statements`` body statements. That parse check runs
        every ``stop_check_interval`` tokens once ``min_tokens_before_stop_check``
        tokens have been generated.

        Returns a dict with keys ``code``, ``syntax_ok``, ``generated_tokens``,
        ``temperature``, ``top_p`` and ``completion_score``. For non-Python
        languages ``syntax_ok`` is always True and ``completion_score`` is 0.
        """
        input_ids, eos_id = self._encode_prompt(prompt, language)
        # Guard against a zero/negative interval (would be a modulo by zero).
        check_every = max(1, int(stop_check_interval))

        generated_steps = 0
        for _ in range(max_new_tokens):
            # The full sequence (prompt + generated so far) is passed every step.
            logits = self.model(input_ids=input_ids)["logits"][:, -1, :]
            next_id = self._sample_next(logits, temperature=temperature, top_p=top_p)
            input_ids = torch.cat([input_ids, next_id], dim=1)
            generated_steps += 1

            # Primary stop: EOS token.
            if eos_id is not None and int(next_id.item()) == eos_id:
                break

            # Secondary stop: complete parseable function with non-trivial body.
            if generated_steps >= min_tokens_before_stop_check and generated_steps % check_every == 0:
                if self._looks_complete_function(
                    self._decode_code(input_ids), min_body_statements=min_function_body_statements
                ):
                    break

        code = self._decode_code(input_ids)
        is_python = language == "python"
        return {
            "code": code,
            "syntax_ok": self._syntax_ok_python(code) if is_python else True,
            "generated_tokens": generated_steps,
            "temperature": temperature,
            "top_p": top_p,
            "completion_score": self._function_completion_score(code) if is_python else 0,
        }

    @torch.no_grad()
    def generate_with_retry(
        self,
        prompt: str,
        language: str = "python",
        cfg: Optional[DecodingConfig] = None,
    ) -> Dict[str, object]:
        """Generate code, retrying with sampling until the output parses.

        Attempt 1 is greedy; attempts 2 and 3 use the retry temperature/top-p
        pairs from ``cfg``. ``min(cfg.max_retries, 3)`` attempts run at most, and
        always at least one (a ``max_retries`` of 0 or less still runs the greedy
        attempt). Each attempt dict gains an ``attempt`` key (1-based).

        Returns ``{"final": ..., "attempts": [...], "used_retry": bool}``. The
        first syntactically valid attempt is returned immediately, with
        ``used_retry`` True only if it was not attempt 1. If none is valid,
        ``final`` is the attempt with the highest
        ``(completion_score, generated_tokens)``; ties go to the earliest attempt,
        and ``used_retry`` is True.
        """
        cfg = cfg or DecodingConfig()
        schedule: List[Tuple[float, float]] = [
            (cfg.greedy_temperature, 1.0),
            (cfg.retry2_temperature, cfg.retry2_top_p),
            (cfg.retry3_temperature, cfg.retry3_top_p),
        ]
        # Always run at least one attempt; otherwise the fallback below has nothing to pick.
        n_attempts = max(1, min(cfg.max_retries, len(schedule)))

        results: List[Dict[str, object]] = []
        for attempt_no, (temp, top_p) in enumerate(schedule[:n_attempts], start=1):
            res = self._generate_once(
                prompt=prompt,
                language=language,
                max_new_tokens=cfg.max_new_tokens,
                temperature=temp,
                top_p=top_p,
                min_tokens_before_stop_check=cfg.min_tokens_before_stop_check,
                min_function_body_statements=cfg.min_function_body_statements,
                stop_check_interval=cfg.stop_check_interval,
            )
            res["attempt"] = attempt_no
            results.append(res)
            # Syntax-aware retry: stop retries as soon as syntax is valid.
            if bool(res["syntax_ok"]):
                return {
                    "final": res,
                    "attempts": results,
                    "used_retry": attempt_no > 1,
                }

        # If all retries fail, choose best completion score then longest generation.
        # max() returns the first maximal element, so earlier attempts win ties.
        best = max(
            results,
            key=lambda r: (int(r.get("completion_score", 0)), int(r.get("generated_tokens", 0))),
        )
        return {
            "final": best,
            "attempts": results,
            "used_retry": True,
        }