# codewraith / main.py
# Uploaded by slenk with huggingface_hub (commit 7cba1fe, verified).
"""Gradio application for CodeWraith inference.
Provides a web interface for generating technical specifications from
Python source code using the fine-tuned student model. Deployed on
HuggingFace Spaces for remote access (instructor evaluation).
Sampling parameters (temperature, top_p, max_tokens) are exposed
as UI controls for experimentation.
"""
from __future__ import annotations
from pathlib import Path
from typing import Any
from codewraith import SYSTEM_MESSAGE
# Default snippet pre-loaded into the code editor and offered as the first
# example. It must be valid, properly indented Python so users can run or
# paste it elsewhere unchanged.
EXAMPLE_CODE = '''\
def fibonacci(n: int) -> list[int]:
    """Generate the first n Fibonacci numbers."""
    if n <= 0:
        return []
    sequence = [0, 1]
    while len(sequence) < n:
        sequence.append(sequence[-1] + sequence[-2])
    return sequence[:n]
'''
# Process-wide lazy caches. All three start empty; load_model() fills the
# model/tokenizer pair and init_retriever() fills the optional RAG retriever.
_model = _tokenizer = _retriever = None
def load_model(
    adapter_dir: str = "./models/codewraith-lora-3b",
    model_key: str = "3b",
) -> tuple[Any, Any]:
    """Load the fine-tuned model and LoRA adapter, caching the result.

    The loaded pair is stored in module globals. After the first successful
    call, every later call returns the cached pair and ignores
    ``adapter_dir`` and ``model_key``. ``main()`` depends on this: it
    preloads the auto-detected adapter, and ``generate_spec()`` then calls
    this function with no arguments to reuse it.

    Args:
        adapter_dir: Path to the LoRA adapter directory.
        model_key: Base model key ("3b" or "8b").

    Returns:
        Tuple of (model, tokenizer).
    """
    global _model, _tokenizer  # noqa: PLW0603
    if _model is not None:
        return _model, _tokenizer

    # Heavy ML imports are deferred so the module can be imported cheaply.
    # Unsloth must be imported before peft (which imports transformers);
    # otherwise it warns that its patches/optimizations may not be applied.
    from unsloth import FastLanguageModel  # isort: skip
    from peft import PeftModel

    from codewraith.student.trainer import load_base_model

    model, tokenizer = load_base_model(model_key)
    # Wrap the base model with the trained LoRA adapter weights.
    model = PeftModel.from_pretrained(model, adapter_dir)
    # Put the model into Unsloth's inference mode before any generation.
    FastLanguageModel.for_inference(model)
    _model, _tokenizer = model, tokenizer
    return model, tokenizer
def init_retriever() -> Any:
    """Return the shared RAG retriever, initializing it on first use.

    The retriever is cached only when the ``data/chromadb`` directory exists
    and its collection holds at least one example. Until then, every call
    retries, so an index built while the app is running is picked up later.

    Returns:
        The cached ``SpecRetriever``, or ``None`` when RAG is unavailable
        (retriever dependencies not importable, index missing, or empty).
    """
    global _retriever  # noqa: PLW0603
    if _retriever is not None:
        return _retriever
    try:
        from codewraith.app.retriever import SpecRetriever

        retriever = SpecRetriever()
        if Path("data/chromadb").exists():
            collection = retriever._get_collection()
            # Query the vector store once and reuse the result.
            n_examples = collection.count()
            if n_examples > 0:
                _retriever = retriever
                print(f"RAG retriever loaded ({n_examples} examples)")
    except ImportError as exc:
        # RAG is optional: keep serving context-free generations, but say
        # why. The default warnings filter shows this once per location, so
        # repeated requests do not spam the log.
        import warnings

        warnings.warn(
            f"RAG retriever unavailable (missing dependency): {exc}",
            stacklevel=2,
        )
    # Still None unless the index was found and non-empty above.
    return _retriever
def generate_spec(
    source_code: str,
    temperature: float = 0.7,
    top_p: float = 0.9,
    max_tokens: int = 2048,
    use_rag: bool = True,
) -> str:
    """Generate a technical specification from Python source code.

    When ``use_rag`` is set and a retriever is available, up to three similar
    code/spec pairs are prepended to the prompt as few-shot context.

    Args:
        source_code: Python source code to analyze. ``None`` or blank input
            returns a hint message without touching the model.
        temperature: Sampling temperature (higher = more creative). A value
            ``<= 0`` selects greedy decoding, because sampling requires a
            strictly positive temperature.
        top_p: Nucleus sampling threshold (ignored for greedy decoding).
        max_tokens: Maximum tokens to generate. Coerced to ``int`` because
            UI sliders may deliver floats.
        use_rag: Whether to use RAG retrieval for context.

    Returns:
        Generated Markdown specification.
    """
    if not source_code or not source_code.strip():
        return "*Please paste some Python source code.*"

    import torch

    model, tokenizer = load_model()

    # Optionally prepend retrieved few-shot examples to the user's code.
    user_content = source_code
    if use_rag:
        retriever = init_retriever()
        if retriever is not None:
            examples = retriever.retrieve(source_code, n_results=3)
            if examples:
                user_content = retriever.format_context(examples) + source_code

    messages = [
        {"role": "system", "content": SYSTEM_MESSAGE},
        {"role": "user", "content": user_content},
    ]
    input_ids = tokenizer.apply_chat_template(
        messages,
        tokenize=True,
        add_generation_prompt=True,
        return_tensors="pt",
    ).to(model.device)

    gen_kwargs: dict[str, Any] = {"max_new_tokens": int(max_tokens)}
    if temperature > 0:
        gen_kwargs.update(do_sample=True, temperature=temperature, top_p=top_p)
    else:
        # transformers rejects temperature == 0 when do_sample=True; the UI
        # slider allows 0, so treat it as a request for greedy decoding.
        gen_kwargs["do_sample"] = False

    outputs = model.generate(
        input_ids=input_ids,
        # A single unpadded prompt: every position is a real token. Passing
        # the mask explicitly stops transformers from inferring it from the
        # pad token.
        attention_mask=torch.ones_like(input_ids),
        **gen_kwargs,
    )
    # Skip the echoed prompt tokens and decode only the newly generated ones.
    generated = outputs[0][input_ids.shape[-1]:]
    return tokenizer.decode(generated, skip_special_tokens=True)
# Second bundled example: a small class. It is kept self-contained (it
# imports ``Any``) and properly indented, so it is valid, runnable Python
# when loaded into the editor.
STACK_EXAMPLE_CODE = '''\
from typing import Any


class Stack:
    def __init__(self):
        self._items = []

    def push(self, item: Any) -> None:
        self._items.append(item)

    def pop(self) -> Any:
        if not self._items:
            raise IndexError("pop from empty stack")
        return self._items.pop()

    def peek(self) -> Any:
        if not self._items:
            raise IndexError("peek at empty stack")
        return self._items[-1]

    @property
    def is_empty(self) -> bool:
        return len(self._items) == 0
'''


def create_app():
    """Create the Gradio application interface.

    The left column holds the code editor, sampling controls, a RAG toggle
    and the generate button. The right column shows the rendered Markdown
    specification. Clicking the button calls ``generate_spec`` with the
    current control values.

    Returns:
        A Gradio Blocks app ready to .launch().
    """
    import gradio as gr

    # Fixed colors for Mermaid diagram nodes, labels and edges in the
    # rendered output; ``!important`` overrides the theme's styles.
    mermaid_css = """
    .mermaid .node rect,
    .mermaid .node polygon,
    .mermaid .node circle {
        fill: #e8f0fe !important;
        stroke: #4a6fa5 !important;
    }
    .mermaid .nodeLabel,
    .mermaid .edgeLabel,
    .mermaid text {
        color: #1a1a1a !important;
        fill: #1a1a1a !important;
    }
    .mermaid .edgePath .path {
        stroke: #4a6fa5 !important;
    }
    """

    with gr.Blocks(
        title="CodeWraith - Module-to-Spec Transformer",
        theme=gr.themes.Soft(),
        css=mermaid_css,
    ) as app:
        gr.Markdown(
            "# CodeWraith\n"
            "Generate technical specifications from Python source code.\n\n"
            "Paste your Python code on the left, adjust sampling parameters, "
            "and click **Generate Specification**."
        )

        with gr.Row():
            with gr.Column(scale=1):
                code_input = gr.Code(
                    language="python",
                    label="Python Source Code",
                    value=EXAMPLE_CODE,
                    lines=20,
                )
                with gr.Row():
                    temperature = gr.Slider(0.0, 2.0, value=0.7, step=0.1, label="Temperature")
                    top_p = gr.Slider(0.0, 1.0, value=0.9, step=0.05, label="Top-p")
                max_tokens = gr.Slider(256, 8192, value=4096, step=256, label="Max Tokens")
                use_rag = gr.Checkbox(value=True, label="Use RAG (retrieve similar examples)")
                generate_btn = gr.Button("Generate Specification", variant="primary")

            with gr.Column(scale=1):
                spec_output = gr.Markdown(label="Generated Specification")

        # Input order must match generate_spec's positional parameters.
        generate_btn.click(
            fn=generate_spec,
            inputs=[code_input, temperature, top_p, max_tokens, use_rag],
            outputs=spec_output,
        )

        gr.Examples(
            examples=[
                [EXAMPLE_CODE],
                [STACK_EXAMPLE_CODE],
            ],
            inputs=[code_input],
            label="Example Inputs",
        )

    return app
def main():
    """Launch the Gradio app, preloading the preferred LoRA adapter.

    Adapter directories are probed in preference order (8B, then 3B). The
    first one that exists is loaded via ``load_model`` before the UI starts.
    If none exists, a warning is printed and the UI is launched anyway.
    """
    # (adapter directory, base-model key), most preferred first.
    adapter_candidates = (
        ("./models/codewraith-lora-8b", "8b"),
        ("./models/codewraith-lora-3b", "3b"),
    )
    chosen = next(
        ((path, key) for path, key in adapter_candidates if Path(path).exists()),
        None,
    )
    if chosen is None:
        print("WARNING: No adapter found. Run training first.")
    else:
        adapter_path, key = chosen
        print(f"Using adapter: {adapter_path}")
        load_model(adapter_dir=adapter_path, model_key=key)

    create_app().launch(share=True)


if __name__ == "__main__":
    main()