"""Gradio application for CodeWraith inference. Provides a web interface for generating technical specifications from Python source code using the fine-tuned student model. Deployed on HuggingFace Spaces for remote access (instructor evaluation). Sampling parameters (temperature, top_p, max_tokens) are exposed as UI controls for experimentation. """ from __future__ import annotations from pathlib import Path from typing import Any from codewraith import SYSTEM_MESSAGE EXAMPLE_CODE = '''\ def fibonacci(n: int) -> list[int]: """Generate the first n Fibonacci numbers.""" if n <= 0: return [] sequence = [0, 1] while len(sequence) < n: sequence.append(sequence[-1] + sequence[-2]) return sequence[:n] ''' # Global model state _model = None _tokenizer = None _retriever = None def load_model( adapter_dir: str = "./models/codewraith-lora-3b", model_key: str = "3b", ) -> tuple[Any, Any]: """Load the fine-tuned model and LoRA adapter. Args: adapter_dir: Path to the LoRA adapter directory. model_key: Base model key ("3b" or "8b"). Returns: Tuple of (model, tokenizer). """ global _model, _tokenizer # noqa: PLW0603 if _model is not None: return _model, _tokenizer from peft import PeftModel from unsloth import FastLanguageModel from codewraith.student.trainer import load_base_model model, tokenizer = load_base_model(model_key) model = PeftModel.from_pretrained(model, adapter_dir) FastLanguageModel.for_inference(model) _model, _tokenizer = model, tokenizer return model, tokenizer def init_retriever() -> Any: """Initialize the RAG retriever if the index exists.""" global _retriever # noqa: PLW0603 if _retriever is not None: return _retriever try: from codewraith.app.retriever import SpecRetriever retriever = SpecRetriever() if Path("data/chromadb").exists(): collection = retriever._get_collection() if collection.count() > 0: _retriever = retriever print(f"RAG retriever loaded ({collection.count()} examples)") return _retriever except ImportError: pass return None def generate_spec( source_code: str, temperature: float = 0.7, top_p: float = 0.9, max_tokens: int = 2048, use_rag: bool = True, ) -> str: """Generate a technical specification from Python source code. Uses RAG to retrieve similar code/spec pairs as few-shot context when available, improving generation quality. Args: source_code: Python source code to analyze. temperature: Sampling temperature (higher = more creative). top_p: Nucleus sampling threshold. max_tokens: Maximum tokens to generate. use_rag: Whether to use RAG retrieval for context. Returns: Generated Markdown specification. """ if not source_code.strip(): return "*Please paste some Python source code.*" model, tokenizer = load_model() # Build user content with optional RAG context user_content = source_code if use_rag: retriever = init_retriever() if retriever is not None: examples = retriever.retrieve(source_code, n_results=3) if examples: context = retriever.format_context(examples) user_content = context + source_code messages = [ {"role": "system", "content": SYSTEM_MESSAGE}, {"role": "user", "content": user_content}, ] inputs = tokenizer.apply_chat_template( messages, tokenize=True, add_generation_prompt=True, return_tensors="pt", ).to(model.device) outputs = model.generate( input_ids=inputs, max_new_tokens=max_tokens, temperature=temperature, top_p=top_p, do_sample=True, ) generated = outputs[0][inputs.shape[-1] :] return tokenizer.decode(generated, skip_special_tokens=True) def create_app(): """Create the Gradio application interface. 
def create_app():
    """Create the Gradio application interface.

    Returns:
        A Gradio Blocks app ready to .launch().
    """
    import gradio as gr

    mermaid_css = """
    .mermaid .node rect, .mermaid .node polygon, .mermaid .node circle {
        fill: #e8f0fe !important;
        stroke: #4a6fa5 !important;
    }
    .mermaid .nodeLabel, .mermaid .edgeLabel, .mermaid text {
        color: #1a1a1a !important;
        fill: #1a1a1a !important;
    }
    .mermaid .edgePath .path {
        stroke: #4a6fa5 !important;
    }
    """

    with gr.Blocks(
        title="CodeWraith - Module-to-Spec Transformer",
        theme=gr.themes.Soft(),
        css=mermaid_css,
    ) as app:
        gr.Markdown(
            "# CodeWraith\n"
            "Generate technical specifications from Python source code.\n\n"
            "Paste your Python code on the left, adjust sampling parameters, "
            "and click **Generate Specification**."
        )
        with gr.Row():
            with gr.Column(scale=1):
                code_input = gr.Code(
                    language="python",
                    label="Python Source Code",
                    value=EXAMPLE_CODE,
                    lines=20,
                )
                with gr.Row():
                    temperature = gr.Slider(0.0, 2.0, value=0.7, step=0.1, label="Temperature")
                    top_p = gr.Slider(0.0, 1.0, value=0.9, step=0.05, label="Top-p")
                max_tokens = gr.Slider(256, 8192, value=4096, step=256, label="Max Tokens")
                use_rag = gr.Checkbox(value=True, label="Use RAG (retrieve similar examples)")
                generate_btn = gr.Button("Generate Specification", variant="primary")
            with gr.Column(scale=1):
                spec_output = gr.Markdown(label="Generated Specification")

        generate_btn.click(
            fn=generate_spec,
            inputs=[code_input, temperature, top_p, max_tokens, use_rag],
            outputs=spec_output,
        )

        gr.Examples(
            examples=[
                [EXAMPLE_CODE],
                [
                    "class Stack:\n"
                    "    def __init__(self):\n"
                    "        self._items = []\n\n"
                    "    def push(self, item: Any) -> None:\n"
                    "        self._items.append(item)\n\n"
                    "    def pop(self) -> Any:\n"
                    "        if not self._items:\n"
                    '            raise IndexError("pop from empty stack")\n'
                    "        return self._items.pop()\n\n"
                    "    def peek(self) -> Any:\n"
                    "        if not self._items:\n"
                    '            raise IndexError("peek at empty stack")\n'
                    "        return self._items[-1]\n\n"
                    "    @property\n"
                    "    def is_empty(self) -> bool:\n"
                    "        return len(self._items) == 0\n"
                ],
            ],
            inputs=[code_input],
            label="Example Inputs",
        )

    return app


def main():
    """Entry point for running the Gradio app."""
    # Auto-detect adapter path
    for candidate in [
        "./models/codewraith-lora-8b",
        "./models/codewraith-lora-3b",
    ]:
        if Path(candidate).exists():
            print(f"Using adapter: {candidate}")
            model_key = "8b" if "8b" in candidate else "3b"
            load_model(adapter_dir=candidate, model_key=model_key)
            break
    else:
        print("WARNING: No adapter found. Run training first.")

    app = create_app()
    app.launch(share=True)


if __name__ == "__main__":
    main()
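
# Deployment note (assumption, not verified against the Spaces config): on
# HuggingFace Spaces, share=True is redundant because Spaces already exposes
# a public URL; a Spaces-style launch would bind the server directly, e.g.:
#
#     create_app().launch(server_name="0.0.0.0", server_port=7860)
#
# share=True is kept in main() so local runs off-Spaces still get a temporary
# public link for remote (instructor) access.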