| """CodeWraith HuggingFace Spaces entry point. | |
| Downloads the LoRA adapter from HF Hub and serves the Gradio interface. | |
| Set HF_REPO_ID environment variable to point to your uploaded adapter. | |
| """ | |

from __future__ import annotations

import os
import sys
from pathlib import Path
from typing import Any

# Ensure src/ is importable (HF Spaces runs app.py directly)
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "src"))

# Force pandas to fully initialize before transitive imports trigger a circular import
import pandas  # noqa: F401, I001

import gradio as gr
import spaces

from codewraith import SYSTEM_MESSAGE

# --- Config ---
HF_REPO_ID = os.environ.get("HF_REPO_ID", "slenk/codewraith-lora-8b")
MODEL_KEY = os.environ.get("MODEL_KEY", "8b")
ADAPTER_DIR = "./adapter"

MODELS = {
    "3b": "unsloth/Llama-3.2-3B-Instruct",
    "8b": "unsloth/Llama-3.1-8B-Instruct",
}
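
# Both settings can be overridden at deploy time. An illustrative invocation for
# local testing (the repo id here is hypothetical):
#     HF_REPO_ID=your-org/your-adapter MODEL_KEY=3b python app.py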

EXAMPLE_CODE = '''\
def fibonacci(n: int) -> list[int]:
    """Generate the first n Fibonacci numbers."""
    if n <= 0:
        return []
    sequence = [0, 1]
    while len(sequence) < n:
        sequence.append(sequence[-1] + sequence[-2])
    return sequence[:n]
'''
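
# For reference, the example above gives fibonacci(5) == [0, 1, 1, 2, 3].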

# --- Global state ---
_model = None
_tokenizer = None
_retriever = None
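# All three are lazy-loaded: the Space can start without a GPU, and the
# heavyweight objects are only built on the first request that needs them.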


# --- Model loading ---
def download_adapter():
    """Download the LoRA adapter from HF Hub if not already cached."""
    if Path(ADAPTER_DIR).exists() and any(Path(ADAPTER_DIR).iterdir()):
        print(f"Adapter already cached at {ADAPTER_DIR}")
        return
    from huggingface_hub import snapshot_download

    print(f"Downloading adapter from {HF_REPO_ID}...")
    snapshot_download(repo_id=HF_REPO_ID, local_dir=ADAPTER_DIR)
    print("Download complete.")


def load_model() -> tuple[Any, Any]:
    """Load the base model with the LoRA adapter, caching both globally."""
    global _model, _tokenizer  # noqa: PLW0603
    if _model is not None:
        return _model, _tokenizer

    download_adapter()

    from peft import PeftModel
    from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig

    model_name = MODELS[MODEL_KEY]
    print(f"Loading {model_name}...")
    bnb_config = BitsAndBytesConfig(load_in_4bit=True)
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        quantization_config=bnb_config,
        device_map="auto",
    )
    model = PeftModel.from_pretrained(model, ADAPTER_DIR)
    model.eval()

    _model, _tokenizer = model, tokenizer
    return model, tokenizer
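
# A richer 4-bit setup is possible here; a sketch (NF4 with bfloat16 compute is
# a common pairing, not something this app requires):
#     BitsAndBytesConfig(
#         load_in_4bit=True,
#         bnb_4bit_quant_type="nf4",
#         bnb_4bit_compute_dtype=torch.bfloat16,  # needs `import torch`
#     )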


# --- RAG ---
def init_retriever():
    """Initialize the retriever if a ChromaDB index exists."""
    global _retriever  # noqa: PLW0603
    if _retriever is not None:
        return _retriever
    try:
        from codewraith.app.retriever import SpecRetriever

        retriever = SpecRetriever()
        if Path("data/chromadb").exists():
            collection = retriever._get_collection()
            if collection.count() > 0:
                _retriever = retriever
                print(f"RAG retriever loaded ({collection.count()} examples)")
                return _retriever
    except ImportError:
        pass
    return None


def retrieve_context(source_code: str, n_results: int = 3) -> str:
    """Retrieve similar examples as context."""
    retriever = init_retriever()
    if retriever is None:
        return ""
    examples = retriever.retrieve(source_code, n_results=n_results)
    if not examples:
        return ""
    return retriever.format_context(examples)
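
# Assumed SpecRetriever contract (it lives in codewraith.app.retriever, not in
# this file): retrieve() returns a possibly-empty list of similar code/spec
# examples, and format_context() serializes them into a prompt prefix string.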


# --- Inference ---
@spaces.GPU  # ZeroGPU: allocate a GPU for the duration of each call
def generate_spec(
    source_code: str,
    temperature: float = 0.7,
    top_p: float = 0.9,
    max_tokens: int = 2048,
    use_rag: bool = True,
) -> str:
    """Generate a technical specification."""
    if not source_code.strip():
        return "*Please paste some Python source code.*"
    try:
        model, tokenizer = load_model()

        user_content = source_code
        if use_rag:
            context = retrieve_context(source_code)
            if context:
                user_content = context + source_code

        messages = [
            {"role": "system", "content": SYSTEM_MESSAGE},
            {"role": "user", "content": user_content},
        ]
        input_text = tokenizer.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )
        inputs = tokenizer(input_text, return_tensors="pt").to(model.device)
        input_len = inputs["input_ids"].shape[-1]

        # Retry without RAG context if the prompt is too long
        if input_len > 6000 and use_rag:
            messages = [
                {"role": "system", "content": SYSTEM_MESSAGE},
                {"role": "user", "content": source_code},
            ]
            input_text = tokenizer.apply_chat_template(
                messages, tokenize=False, add_generation_prompt=True
            )
            inputs = tokenizer(input_text, return_tensors="pt").to(model.device)
            input_len = inputs["input_ids"].shape[-1]

        outputs = model.generate(
            **inputs,
            max_new_tokens=max_tokens,
            temperature=temperature,
            top_p=top_p,
            do_sample=True,
        )
        generated = outputs[0][input_len:]
        spec = tokenizer.decode(generated, skip_special_tokens=True)
        return render_mermaid_images(spec)
    except Exception as e:
        return (
            f"**Error generating specification:**\n\n```\n{e}\n```\n\n"
            "Try a shorter input or disable RAG."
        )
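
# Illustrative direct call (the parameter values are just plausible choices):
#     spec_md = generate_spec(EXAMPLE_CODE, temperature=0.2, max_tokens=1024)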


# --- Output rendering ---
def render_mermaid_images(spec: str) -> str:
    """Replace mermaid code blocks with rendered SVG images via mermaid.ink.

    Validates the mermaid syntax first, strips malformed blocks, and
    converts valid ones to inline images that render reliably regardless
    of CSS/theme issues.
    """
    import base64
    import re

    valid_starts = (
        "graph ",
        "graph\n",
        "flowchart ",
        "flowchart\n",
        "classDiagram",
        "sequenceDiagram",
        "stateDiagram",
        "erDiagram",
        "gantt",
        "pie",
        "gitGraph",
    )

    def replace_block(match: re.Match) -> str:
        block = match.group(1).strip()
        # Must start with a valid diagram type
        if not any(block.startswith(s) for s in valid_starts):
            return "*[Mermaid diagram removed: unrecognized diagram type]*"
        # Check for balanced brackets, braces, and parentheses
        if block.count("[") != block.count("]"):
            return "*[Mermaid diagram removed: unbalanced brackets]*"
        if block.count("{") != block.count("}"):
            return "*[Mermaid diagram removed: unbalanced braces]*"
        if block.count("(") != block.count(")"):
            return "*[Mermaid diagram removed: unbalanced parentheses]*"
        # Encode and return as a mermaid.ink image
        encoded = base64.urlsafe_b64encode(block.encode("utf-8")).decode("ascii")
        url = f"https://mermaid.ink/svg/{encoded}"
        return f'<img src="{url}" alt="Dependency Diagram" style="max-width: 600px;">'

    return re.sub(r"```mermaid\s*\n(.*?)```", replace_block, spec, flags=re.DOTALL)
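
# Illustrative transformation: a fenced mermaid block whose body is
# "graph TD\nA --> B" passes validation and is replaced by an <img> whose
# src is https://mermaid.ink/svg/<urlsafe-b64 of that body>.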


# --- Gradio UI ---
def create_app():
    with gr.Blocks(
        title="CodeWraith - Module-to-Spec Transformer",
    ) as app:
        gr.Markdown(
            "# CodeWraith\n"
            "Generate technical specifications from Python source code.\n\n"
            "Paste your Python code below, adjust sampling parameters, "
            "and click **Generate Specification**."
        )
        code_input = gr.Code(
            language="python",
            label="Python Source Code",
            value=EXAMPLE_CODE,
            lines=15,
        )
        with gr.Row():
            temperature = gr.Slider(0.0, 2.0, value=0.7, step=0.1, label="Temperature")
            top_p = gr.Slider(0.0, 1.0, value=0.9, step=0.05, label="Top-p")
            max_tokens = gr.Slider(256, 8192, value=4096, step=256, label="Max Tokens")
        with gr.Row():
            use_rag = gr.Checkbox(value=True, label="Use RAG (retrieve similar examples)")
            generate_btn = gr.Button("Generate Specification", variant="primary")
            clear_input_btn = gr.Button("Clear Input", variant="secondary")
            clear_output_btn = gr.Button("Clear Output", variant="secondary")
        gr.Markdown("*Model loads on first generation (~30s). Subsequent calls are fast.*")
        spec_output = gr.Markdown(label="Generated Specification")

        loading_msg = "*Generating specification... (loading model if first run)*"
        generate_btn.click(
            fn=lambda: gr.update(value=loading_msg),
            outputs=spec_output,
        ).then(
            fn=generate_spec,
            inputs=[code_input, temperature, top_p, max_tokens, use_rag],
            outputs=spec_output,
        )
        clear_input_btn.click(
            fn=lambda: "",
            outputs=code_input,
        )
        clear_output_btn.click(
            fn=lambda: "",
            outputs=spec_output,
        )
    return app


# Preload the adapter at startup (runs on free CPU time, before any GPU is needed)
print("Preloading adapter...")
download_adapter()
print("Adapter ready. Model will load on first GPU request.")

if __name__ == "__main__":
    app = create_app()
    app.launch()
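
# Assumption for local debugging: app.launch(server_name="0.0.0.0") would expose
# the UI on the LAN; on HF Spaces the platform supplies host/port, so the plain
# launch() above is all that's needed.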