| """ |
| Component 8: Local chat interface using Gradio. |
| |
| - Clean dark-themed UI. |
| - Prompt input box. |
| - Syntax-highlighted code output (Python + JavaScript). |
| - Copy button for each code response. |
| - Generation time + token count. |
| - Conversation history in session. |
| - Clear button to reset history. |
| - Live model selector: Base / LoRA / INT8 (no restart). |
| """ |
|
|
| from __future__ import annotations |
|
|
| import html |
| import re |
| import time |
| from pathlib import Path |
| from typing import Any, Dict, List, Optional, Tuple |
|
|
| import gradio as gr |
| import torch |
| import torch.nn as nn |
| import yaml |
| from pygments import highlight |
| from pygments.formatters import HtmlFormatter |
| from pygments.lexers import JavascriptLexer, PythonLexer, TextLexer |
|
|
| from src.finetuning_system.lora_adapter import LoRAConfig, apply_lora, load_lora_state_dict |
| from src.inference_engine.inference_engine import DecodingConfig, InferenceEngine |
| from src.model_architecture.code_transformer import CodeTransformerLM, ModelConfig, get_model_presets |
| from src.tokenizer.code_tokenizer import CodeTokenizer |
|
|
|
|
| def _load_yaml(path: Path) -> Dict[str, Any]: |
| if not path.exists(): |
| raise FileNotFoundError(f"Config file not found: {path}") |
| data = yaml.safe_load(path.read_text(encoding="utf-8-sig")) |
| if not isinstance(data, dict): |
| raise ValueError("Invalid YAML format.") |
| return data |
|
|
|
|
def _build_model_config(path: Path) -> ModelConfig:
    """Build a ``ModelConfig`` from a YAML file, optionally on top of a preset.

    The YAML document may contain:

    * ``preset``: name of an entry returned by ``get_model_presets()``.
    * ``model``: mapping of ``ModelConfig`` keyword arguments. When a preset
      is given these values override the preset's fields.

    Args:
        path: Location of the model configuration YAML file.

    Returns:
        The resulting ``ModelConfig`` instance.

    Raises:
        FileNotFoundError: If ``path`` does not exist (raised by ``_load_yaml``).
        ValueError: If the YAML is not a mapping or the preset name is unknown.
    """
    cfg = _load_yaml(path)
    # A bare "model:" key parses to None rather than a missing key, so
    # normalise it to an empty mapping before it is used with ** / update().
    overrides = cfg.get("model") or {}
    preset = cfg.get("preset")

    if not preset:
        return ModelConfig(**overrides)

    presets = get_model_presets()
    if preset not in presets:
        raise ValueError(f"Unknown preset: {preset}")

    # Copy the preset's fields so the shared preset object is never mutated.
    merged = presets[preset].__dict__.copy()
    merged.update(overrides)
    return ModelConfig(**merged)
|
|
|
|
| def _guess_language(prompt: str, default_lang: str = "python") -> str: |
| p = prompt.lower() |
| if "javascript" in p or " js " in f" {p} " or "node" in p: |
| return "javascript" |
| if "python" in p: |
| return "python" |
| return default_lang |
|
|
|
|
| def _is_coding_prompt(prompt: str) -> bool: |
| p = prompt.lower().strip() |
| coding_keywords = [ |
| "code", |
| "python", |
| "javascript", |
| "function", |
| "bug", |
| "error", |
| "algorithm", |
| "sort", |
| "loop", |
| "class", |
| "api", |
| "sql", |
| "regex", |
| "debug", |
| "implement", |
| "write", |
| ] |
| if any(k in p for k in coding_keywords): |
| return True |
| if re.fullmatch(r"(hi|hello|hey|yo|hola)[!. ]*", p): |
| return False |
| return False |
|
|
|
|
def _highlight_code(code: str, language: str) -> str:
    """Turn source text into Pygments-highlighted HTML.

    Args:
        code: Source text to highlight; a falsy value is treated as "".
        language: ``"javascript"`` or ``"python"`` select the matching lexer;
            any other value falls back to the plain-text lexer.

    Returns:
        The HTML produced by ``HtmlFormatter(nowrap=True)``; the caller
        supplies the surrounding ``<pre>`` element.
    """
    lexer_by_language = {
        "javascript": JavascriptLexer,
        "python": PythonLexer,
    }
    lexer_cls = lexer_by_language.get(language, TextLexer)
    return highlight(code or "", lexer_cls(), HtmlFormatter(nowrap=True))
|
|
|
|
def _render_history(history: List[Dict[str, Any]]) -> str:
    """Render the conversation history as a single HTML fragment.

    Each history entry is a dict from which these keys are read (all
    optional): ``prompt``, ``code``, ``language``, ``mode``, ``time_sec``,
    ``tokens``, ``syntax_ok`` and ``attempt``.

    Args:
        history: Conversation entries, oldest first. An empty list renders a
            placeholder message.

    Returns:
        HTML containing the Pygments "monokai" stylesheet, the chat styles
        and one block per entry with a Copy button and a metadata line.
    """
    formatter = HtmlFormatter(style="monokai")
    css = formatter.get_style_defs(".codehilite")
    blocks = [
        "<style>",
        css,
        """
    .chat-wrap { background: #0f1117; color: #e5e7eb; padding: 14px; border-radius: 12px; font-family: 'Segoe UI', sans-serif; }
    .entry { border: 1px solid #262a33; background: #151922; border-radius: 10px; padding: 12px; margin-bottom: 12px; }
    .prompt { color: #93c5fd; font-weight: 600; margin-bottom: 8px; white-space: pre-wrap; }
    .meta { color: #9ca3af; font-size: 12px; margin-top: 8px; }
    .code-box { border: 1px solid #2f3542; border-radius: 8px; background: #0b0d12; overflow-x: auto; }
    .code-inner { padding: 12px; font-family: Consolas, 'Courier New', monospace; font-size: 13px; line-height: 1.5; white-space: pre; }
    .copy-btn { background: #1f2937; color: #e5e7eb; border: 1px solid #374151; border-radius: 6px; padding: 5px 10px; cursor: pointer; float: right; margin-bottom: 6px; }
    .copy-btn:hover { background: #374151; }
    .label { font-size: 12px; color: #a1a1aa; margin-bottom: 6px; }
    """,
        "</style>",
        '<div class="chat-wrap">',
    ]

    if not history:
        blocks.append('<div class="entry"><div class="meta">No messages yet. Ask a coding question to begin.</div></div>')

    for i, item in enumerate(history, start=1):
        raw_lang = str(item.get("language", "python"))
        # Everything interpolated into markup is escaped; the lexer still
        # receives the raw language name.
        lang = html.escape(raw_lang)
        mode = html.escape(str(item.get("mode", "base")))
        prompt = html.escape(str(item.get("prompt", "")))
        highlighted = _highlight_code(str(item.get("code", "")), raw_lang)
        code_id = f"code-{i}"

        # Three states: None / missing means "not checked", so a failed
        # syntax check (False) must not be reported as "n/a".
        syntax_flag = item.get("syntax_ok")
        if syntax_flag is None:
            syntax_ok = "n/a"
        elif syntax_flag:
            syntax_ok = "yes"
        else:
            syntax_ok = "no"

        # <script> elements inserted via innerHTML are not executed by
        # browsers, so the handler is inlined instead of calling a global
        # helper function.
        copy_js = f"navigator.clipboard.writeText(document.getElementById('{code_id}').innerText)"
        time_sec = float(item.get("time_sec", 0) or 0)

        blocks.append('<div class="entry">')
        blocks.append(f'<div class="prompt">User: {prompt}</div>')
        blocks.append(f'<div class="label">Assistant ({lang})</div>')
        blocks.append(f'<button class="copy-btn" onclick="{copy_js}">Copy</button>')
        blocks.append('<div style="clear: both"></div>')
        blocks.append('<div class="code-box">')
        blocks.append(f'<pre class="code-inner codehilite" id="{code_id}">{highlighted}</pre>')
        blocks.append('</div>')
        blocks.append(
            f'<div class="meta">mode={mode} | time={time_sec:.2f}s | '
            f'tokens={html.escape(str(item.get("tokens", 0)))} | syntax_ok={syntax_ok} | '
            f'attempt={html.escape(str(item.get("attempt", 1)))}</div>'
        )
        blocks.append('</div>')

    blocks.append('</div>')
    return "\n".join(blocks)
|
|
|
|
class ChatRuntime:
    """Owns the chat configuration, tokenizer and the currently loaded model.

    Only one inference engine is held at a time. Switching between the
    ``base``, ``lora`` and ``int8`` modes drops the current engine and builds
    the requested one, so the UI can change models without a restart.

    Attributes are set in ``__init__``:
        project_root: Directory two levels above this file's directory.
        cfg: Parsed chat configuration YAML.
        model_cfg: ``ModelConfig`` built from ``cfg["model"]["model_config_path"]``.
        cuda_device: The CUDA device used by the base and LoRA modes.
        tokenizer: Tokenizer loaded from ``cfg["model"]["tokenizer_dir"]``.
        decode_cfg: Decoding settings read from ``cfg["inference"]``.
        current_mode: Name of the loaded mode, or ``None`` when nothing is loaded.
        engine: The active ``InferenceEngine``, or ``None``.
    """

    _VALID_MODES = ("base", "lora", "int8")

    def __init__(self, config_path: str) -> None:
        """Load configuration and tokenizer; no model is loaded yet.

        Args:
            config_path: Chat config YAML path, relative to the project root.

        Raises:
            RuntimeError: If no CUDA device is available.
            FileNotFoundError / ValueError: Propagated from the YAML loaders.
            KeyError: If a required config section or key is missing.
        """
        self.project_root = Path(__file__).resolve().parents[2]
        self.cfg = _load_yaml(self.project_root / config_path)

        self.model_cfg = _build_model_config(self.project_root / self.cfg["model"]["model_config_path"])
        self.cuda_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        if self.cuda_device.type != "cuda":
            raise RuntimeError("CUDA GPU is required for this chat interface setup.")

        self.tokenizer = CodeTokenizer.load(str(self.project_root / self.cfg["model"]["tokenizer_dir"]))

        inference_cfg = self.cfg["inference"]
        self.decode_cfg = DecodingConfig(
            max_new_tokens=int(inference_cfg.get("max_new_tokens", 300)),
            greedy_temperature=float(inference_cfg.get("greedy_temperature", 0.0)),
            retry2_temperature=float(inference_cfg.get("retry2_temperature", 0.25)),
            retry2_top_p=float(inference_cfg.get("retry2_top_p", 0.85)),
            retry3_temperature=float(inference_cfg.get("retry3_temperature", 0.35)),
            retry3_top_p=float(inference_cfg.get("retry3_top_p", 0.90)),
            max_retries=int(inference_cfg.get("max_retries", 3)),
            min_tokens_before_stop_check=int(inference_cfg.get("min_tokens_before_stop_check", 64)),
        )

        self.current_mode: Optional[str] = None
        self.engine: Optional[InferenceEngine] = None

    def _release_current(self) -> None:
        """Drop the active engine and return cached CUDA memory to the driver."""
        self.engine = None
        self.current_mode = None
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    def _current_vram_gb(self) -> float:
        """Return the CUDA memory currently allocated by tensors, in GiB (0.0 without CUDA)."""
        if not torch.cuda.is_available():
            return 0.0
        return float(torch.cuda.memory_allocated() / (1024**3))

    def _status_text(self, mode: str, load_sec: float) -> str:
        """Format the one-line status shown in the UI's status box."""
        return f"MINDI 1.0 420M | mode={mode} | load={load_sec:.2f}s | vram={self._current_vram_gb():.2f}GB"

    def _build_model_with_base_weights(self) -> CodeTransformerLM:
        """Create the model on the CUDA device and load the base checkpoint.

        Shared by the base and LoRA loaders so both read the checkpoint the
        same way.
        """
        model = CodeTransformerLM(self.model_cfg).to(self.cuda_device)
        # NOTE(security): torch.load unpickles the file, so only point
        # base_checkpoint_path at checkpoints from a trusted source.
        payload = torch.load(self.project_root / self.cfg["model"]["base_checkpoint_path"], map_location=self.cuda_device)
        model.load_state_dict(payload["model_state"])
        return model

    def _load_base_model(self) -> InferenceEngine:
        """Build an engine around the base model in half precision on CUDA."""
        model = self._build_model_with_base_weights()
        model.half()
        # Inference only: make sure dropout-style layers are disabled.
        model.eval()
        return InferenceEngine(model=model, tokenizer=self.tokenizer, device=self.cuda_device)

    def _load_lora_model(self) -> InferenceEngine:
        """Build an engine around the base model with the LoRA adapter applied."""
        model = self._build_model_with_base_weights()

        # A bare "lora:" key parses to None; fall back to the defaults then.
        lora_section = self.cfg.get("lora") or {}
        lora_cfg = LoRAConfig(
            r=int(lora_section.get("r", 8)),
            alpha=int(lora_section.get("alpha", 16)),
            dropout=float(lora_section.get("dropout", 0.05)),
            target_keywords=list(lora_section.get("target_keywords", ["q_proj", "k_proj", "v_proj", "o_proj", "fc1", "fc2"])),
        )
        apply_lora(model, lora_cfg)
        # Move again so modules added by apply_lora end up on the CUDA device.
        model = model.to(self.cuda_device)

        # NOTE(security): torch.load unpickles the file; trusted adapters only.
        lora_payload = torch.load(self.project_root / self.cfg["model"]["lora_adapter_path"], map_location=self.cuda_device)
        # The adapter file may wrap the weights under "lora_state" or be the
        # bare state dict itself.
        lora_state = lora_payload.get("lora_state", lora_payload)
        load_lora_state_dict(model, lora_state)
        model.half()
        # Without eval mode the LoRA dropout configured above could stay
        # active during generation, unless the engine switches modes itself.
        model.eval()
        return InferenceEngine(model=model, tokenizer=self.tokenizer, device=self.cuda_device)

    def _load_int8_model(self) -> InferenceEngine:
        """Build a CPU engine around a dynamically quantized (int8) model."""
        cpu = torch.device("cpu")
        model = CodeTransformerLM(self.model_cfg).to(cpu).float()
        # Quantize the Linear layers first so the module structure matches
        # the saved quantized state dict.
        model = torch.quantization.quantize_dynamic(model, {nn.Linear}, dtype=torch.qint8)
        # NOTE(security): torch.load unpickles the file; trusted files only.
        q_state = torch.load(self.project_root / self.cfg["model"]["quantized_state_path"], map_location=cpu)
        model.load_state_dict(q_state)
        model.eval()
        return InferenceEngine(model=model, tokenizer=self.tokenizer, device=cpu)

    def _ensure_mode(self, mode: str) -> str:
        """Make sure the engine for ``mode`` is loaded and return a status line.

        Args:
            mode: ``"base"``, ``"lora"`` or ``"int8"`` (case and surrounding
                whitespace are ignored). Empty or unknown values fall back to
                ``"base"``.

        Returns:
            The status text; ``load`` is 0.00s when the mode was already active.
        """
        mode = (mode or "base").lower().strip()
        if mode not in self._VALID_MODES:
            mode = "base"

        if self.current_mode == mode and self.engine is not None:
            return self._status_text(mode, load_sec=0.0)

        t0 = time.perf_counter()
        # Release first so the old and new models do not coexist in memory.
        self._release_current()
        if mode == "base":
            self.engine = self._load_base_model()
        elif mode == "lora":
            self.engine = self._load_lora_model()
        else:
            self.engine = self._load_int8_model()

        self.current_mode = mode
        load_sec = time.perf_counter() - t0
        return self._status_text(mode, load_sec=load_sec)

    def switch_mode(self, mode: str) -> str:
        """Switch to ``mode`` (loading it if needed) and return the status line."""
        return self._ensure_mode(mode)

    def respond(self, prompt: str, history: List[Dict[str, Any]], mode: str) -> Tuple[str, List[Dict[str, Any]], str, str]:
        """Handle one user prompt and return the updated UI values.

        Args:
            prompt: Raw text from the prompt box.
            history: Conversation entries so far; ``None`` is treated as empty.
                The list passed in is not modified.
            mode: Model mode selected in the UI.

        Returns:
            ``(rendered_html, history, "", status)``; the empty string clears
            the prompt box.
        """
        # Work on a copy so the caller's list is never mutated in place.
        history = list(history or [])
        prompt = (prompt or "").strip()
        status = self._ensure_mode(mode)

        if not prompt:
            return _render_history(history), history, "", status

        if not _is_coding_prompt(prompt):
            fallback = "Please ask a coding question (for example: 'Write a Python function to ...' or 'Fix this JavaScript bug ...')."
            history.append(
                {
                    "prompt": prompt,
                    "code": fallback,
                    "language": "text",
                    "tokens": 0,
                    "time_sec": 0.0,
                    "syntax_ok": None,
                    "attempt": 0,
                    "mode": self.current_mode or "base",
                }
            )
            return _render_history(history), history, "", status

        lang_default = str(self.cfg["inference"].get("language_default", "python"))
        language = _guess_language(prompt, default_lang=lang_default)

        start = time.perf_counter()
        result = self.engine.generate_with_retry(prompt=prompt, language=language, cfg=self.decode_cfg)
        elapsed = time.perf_counter() - start

        final = result["final"]
        history.append(
            {
                "prompt": prompt,
                "code": final["code"],
                "language": language,
                "tokens": int(final.get("generated_tokens", 0)),
                "time_sec": float(elapsed),
                # The syntax flag is only recorded for Python output.
                "syntax_ok": bool(final.get("syntax_ok", False)) if language == "python" else None,
                "attempt": int(final.get("attempt", 1)),
                "mode": self.current_mode or "base",
            }
        )

        return _render_history(history), history, "", status

    def clear(self, mode: str) -> Tuple[str, List[Dict[str, Any]], str, str]:
        """Reset the conversation and return the empty-state UI values."""
        history: List[Dict[str, Any]] = []
        status = self._ensure_mode(mode)
        return _render_history(history), history, "", status
|
|
|
|
def create_demo(config_path: str = "configs/component8_chat_config.yaml") -> gr.Blocks:
    """Assemble the Gradio Blocks app for the chat interface.

    Args:
        config_path: Chat config YAML path handed to ``ChatRuntime``.

    Returns:
        The constructed ``gr.Blocks`` app; launching it is left to the caller.
    """
    runtime = ChatRuntime(config_path=config_path)

    with gr.Blocks(title="MINDI 1.0 420M", theme=gr.themes.Base()) as demo:
        gr.Markdown("## MINDI 1.0 420M\nYour local coding intelligence — 420M parameters, fully offline")

        # Per-session conversation entries plus their rendered HTML view.
        history_state = gr.State([])
        chat_html = gr.HTML(value=_render_history([]))

        with gr.Row():
            mode_dropdown = gr.Dropdown(
                label="Model Mode",
                choices=["base", "lora", "int8"],
                value="base",
                interactive=True,
            )
            status_box = gr.Textbox(
                label="Status",
                value="MINDI 1.0 420M | mode=base | load=0.00s | vram=0.00GB",
                interactive=False,
            )

        prompt_box = gr.Textbox(
            label="Your Prompt",
            lines=4,
            placeholder="Ask MINDI anything about code",
        )

        with gr.Row():
            send_btn = gr.Button("Generate", variant="primary")
            clear_btn = gr.Button("Clear Conversation")
            switch_btn = gr.Button("Apply Mode")

        # Component groups shared by several event handlers.
        respond_inputs = [prompt_box, history_state, mode_dropdown]
        panel_outputs = [chat_html, history_state, prompt_box, status_box]

        switch_btn.click(fn=runtime.switch_mode, inputs=[mode_dropdown], outputs=[status_box])

        # The Generate button and Enter in the prompt box trigger the same handler.
        for register in (send_btn.click, prompt_box.submit):
            register(
                fn=runtime.respond,
                inputs=list(respond_inputs),
                outputs=list(panel_outputs),
                queue=True,
            )

        clear_btn.click(
            fn=runtime.clear,
            inputs=[mode_dropdown],
            outputs=list(panel_outputs),
        )

    return demo
|
|
|
|
|
|