| """Gradio application for CodeWraith inference. | |
| Provides a web interface for generating technical specifications from | |
| Python source code using the fine-tuned student model. Deployed on | |
| HuggingFace Spaces for remote access (instructor evaluation). | |
| Sampling parameters (temperature, top_p, max_tokens) are exposed | |
| as UI controls for experimentation. | |
| """ | |
| from __future__ import annotations | |
| from pathlib import Path | |
| from typing import Any | |
| from codewraith import SYSTEM_MESSAGE | |
# Default snippet pre-filled in the code input box so users can try the
# app immediately without pasting their own code.
EXAMPLE_CODE = '''\
def fibonacci(n: int) -> list[int]:
    """Generate the first n Fibonacci numbers."""
    if n <= 0:
        return []
    sequence = [0, 1]
    while len(sequence) < n:
        sequence.append(sequence[-1] + sequence[-2])
    return sequence[:n]
'''
# Global model state: lazily-initialized process-wide singletons.
# Populated by load_model() / init_retriever() so the heavyweight model
# and the RAG index are loaded at most once per process.
_model = None       # fine-tuned model with LoRA adapter applied
_tokenizer = None   # tokenizer paired with _model
_retriever = None   # SpecRetriever, or None if no usable index was found
def load_model(
    adapter_dir: str = "./models/codewraith-lora-3b",
    model_key: str = "3b",
) -> tuple[Any, Any]:
    """Load the fine-tuned model and LoRA adapter (cached after first call).

    Args:
        adapter_dir: Path to the LoRA adapter directory.
        model_key: Base model key ("3b" or "8b").

    Returns:
        Tuple of (model, tokenizer).
    """
    global _model, _tokenizer  # noqa: PLW0603
    if _model is not None:
        # Fix: the cache previously ignored the arguments entirely, so a
        # second call with a different adapter silently got the first model.
        # We still return the cached model (callers may hold references to
        # it), but now surface the mismatch instead of hiding it.
        cached_key = getattr(load_model, "_cache_key", None)
        if cached_key != (adapter_dir, model_key):
            print(
                f"WARNING: model already loaded as {cached_key}; "
                f"ignoring request for ({adapter_dir!r}, {model_key!r})"
            )
        return _model, _tokenizer
    # Heavy ML imports are deferred so importing this module stays cheap
    # (e.g. for tests or docs builds that never run inference).
    from peft import PeftModel
    from unsloth import FastLanguageModel

    from codewraith.student.trainer import load_base_model

    model, tokenizer = load_base_model(model_key)
    model = PeftModel.from_pretrained(model, adapter_dir)
    FastLanguageModel.for_inference(model)  # switch to inference-optimized mode
    _model, _tokenizer = model, tokenizer
    load_model._cache_key = (adapter_dir, model_key)  # remember what we cached
    return model, tokenizer
def init_retriever() -> Any:
    """Initialize the RAG retriever if a non-empty index exists.

    Returns:
        A cached SpecRetriever when the ChromaDB index exists and contains
        at least one example; otherwise None (retriever package missing,
        index directory absent, or index empty).
    """
    global _retriever  # noqa: PLW0603
    if _retriever is not None:
        return _retriever
    try:
        from codewraith.app.retriever import SpecRetriever

        # Cheap existence check first: don't construct a retriever when
        # there is no index directory at all.
        if not Path("data/chromadb").exists():
            return None
        retriever = SpecRetriever()
        # Hoist count(): the original called collection.count() twice,
        # hitting the database once per call.
        count = retriever._get_collection().count()  # noqa: SLF001 — no public accessor exposed
        if count > 0:
            _retriever = retriever
            print(f"RAG retriever loaded ({count} examples)")
            return _retriever
    except ImportError:
        # Best-effort: RAG is optional; run without it if deps are missing.
        pass
    return None
def generate_spec(
    source_code: str,
    temperature: float = 0.7,
    top_p: float = 0.9,
    max_tokens: int = 2048,
    use_rag: bool = True,
) -> str:
    """Generate a technical specification from Python source code.

    Uses RAG to retrieve similar code/spec pairs as few-shot context
    when available, improving generation quality.

    Args:
        source_code: Python source code to analyze.
        temperature: Sampling temperature (0 = greedy decoding).
        top_p: Nucleus sampling threshold.
        max_tokens: Maximum tokens to generate.
        use_rag: Whether to use RAG retrieval for context.

    Returns:
        Generated Markdown specification.
    """
    if not source_code.strip():
        return "*Please paste some Python source code.*"
    model, tokenizer = load_model()

    # Build user content, prepending retrieved few-shot examples when RAG
    # is enabled and an index is available.
    user_content = source_code
    if use_rag:
        retriever = init_retriever()
        if retriever is not None:
            examples = retriever.retrieve(source_code, n_results=3)
            if examples:
                user_content = retriever.format_context(examples) + source_code

    messages = [
        {"role": "system", "content": SYSTEM_MESSAGE},
        {"role": "user", "content": user_content},
    ]
    inputs = tokenizer.apply_chat_template(
        messages,
        tokenize=True,
        add_generation_prompt=True,
        return_tensors="pt",
    ).to(model.device)

    # Fix: the UI slider allows temperature=0.0, but do_sample=True with
    # temperature 0 raises in transformers' sampling logits processor.
    # Fall back to greedy decoding in that case.
    do_sample = temperature > 0
    outputs = model.generate(
        input_ids=inputs,
        # Gradio sliders deliver floats; generate expects an int count.
        max_new_tokens=int(max_tokens),
        temperature=temperature if do_sample else None,
        top_p=top_p if do_sample else None,
        do_sample=do_sample,
    )
    # Strip the prompt tokens; decode only the newly generated tail.
    generated = outputs[0][inputs.shape[-1]:]
    return tokenizer.decode(generated, skip_special_tokens=True)
def create_app():
    """Create the Gradio application interface.

    Layout: a two-column Blocks app — code input plus sampling controls on
    the left, the generated Markdown spec on the right — with clickable
    example inputs underneath.

    Returns:
        A Gradio Blocks app ready to .launch().
    """
    # Deferred import: gradio is only needed when actually serving the UI.
    import gradio as gr

    # Custom CSS so Mermaid diagrams rendered inside the Markdown output
    # stay readable (light fill, dark text) regardless of theme.
    mermaid_css = """
    .mermaid .node rect,
    .mermaid .node polygon,
    .mermaid .node circle {
        fill: #e8f0fe !important;
        stroke: #4a6fa5 !important;
    }
    .mermaid .nodeLabel,
    .mermaid .edgeLabel,
    .mermaid text {
        color: #1a1a1a !important;
        fill: #1a1a1a !important;
    }
    .mermaid .edgePath .path {
        stroke: #4a6fa5 !important;
    }
    """
    with gr.Blocks(
        title="CodeWraith - Module-to-Spec Transformer",
        theme=gr.themes.Soft(),
        css=mermaid_css,
    ) as app:
        gr.Markdown(
            "# CodeWraith\n"
            "Generate technical specifications from Python source code.\n\n"
            "Paste your Python code on the left, adjust sampling parameters, "
            "and click **Generate Specification**."
        )
        with gr.Row():
            # Left column: input code plus sampling controls.
            with gr.Column(scale=1):
                code_input = gr.Code(
                    language="python",
                    label="Python Source Code",
                    value=EXAMPLE_CODE,
                    lines=20,
                )
                with gr.Row():
                    # NOTE(review): UI defaults (temp 0.7, max 4096) differ
                    # from generate_spec's signature defaults — the UI values
                    # win since all controls are wired as inputs.
                    temperature = gr.Slider(0.0, 2.0, value=0.7, step=0.1, label="Temperature")
                    top_p = gr.Slider(0.0, 1.0, value=0.9, step=0.05, label="Top-p")
                    max_tokens = gr.Slider(256, 8192, value=4096, step=256, label="Max Tokens")
                use_rag = gr.Checkbox(value=True, label="Use RAG (retrieve similar examples)")
                generate_btn = gr.Button("Generate Specification", variant="primary")
            # Right column: generated spec rendered as Markdown.
            with gr.Column(scale=1):
                spec_output = gr.Markdown(label="Generated Specification")
        # Wire the button: slider/checkbox values are passed positionally
        # to generate_spec in the order of its parameters.
        generate_btn.click(
            fn=generate_spec,
            inputs=[code_input, temperature, top_p, max_tokens, use_rag],
            outputs=spec_output,
        )
        # Clickable examples that populate the code input.
        gr.Examples(
            examples=[
                [EXAMPLE_CODE],
                [
                    "class Stack:\n    def __init__(self):\n        self._items = []\n\n"
                    "    def push(self, item: Any) -> None:\n        self._items.append(item)\n\n"
                    "    def pop(self) -> Any:\n        if not self._items:\n"
                    '            raise IndexError("pop from empty stack")\n'
                    "        return self._items.pop()\n\n"
                    "    def peek(self) -> Any:\n        if not self._items:\n"
                    '            raise IndexError("peek at empty stack")\n'
                    "        return self._items[-1]\n\n"
                    "    @property\n    def is_empty(self) -> bool:\n"
                    "        return len(self._items) == 0\n"
                ],
            ],
            inputs=[code_input],
            label="Example Inputs",
        )
    return app
def main():
    """Entry point for running the Gradio app.

    Prefers the 8B adapter when present, falls back to 3B, and warns
    (but still serves the UI) when no adapter has been trained yet.
    """
    candidates = (
        "./models/codewraith-lora-8b",
        "./models/codewraith-lora-3b",
    )
    adapter = next((c for c in candidates if Path(c).exists()), None)
    if adapter is None:
        print("WARNING: No adapter found. Run training first.")
    else:
        print(f"Using adapter: {adapter}")
        model_key = "8b" if "8b" in adapter else "3b"
        load_model(adapter_dir=adapter, model_key=model_key)
    app = create_app()
    app.launch(share=True)
| if __name__ == "__main__": | |
| main() | |