slenk committed on
Commit
7cba1fe
·
verified ·
1 Parent(s): 6db5389

Upload main.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. main.py +255 -0
main.py ADDED
@@ -0,0 +1,255 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Gradio application for CodeWraith inference.
2
+
3
+ Provides a web interface for generating technical specifications from
4
+ Python source code using the fine-tuned student model. Deployed on
5
+ HuggingFace Spaces for remote access (instructor evaluation).
6
+
7
+ Sampling parameters (temperature, top_p, max_tokens) are exposed
8
+ as UI controls for experimentation.
9
+ """
10
+
11
+ from __future__ import annotations
12
+
13
+ from pathlib import Path
14
+ from typing import Any
15
+
16
+ from codewraith import SYSTEM_MESSAGE
17
+
18
# Sample module shown in the code editor on first load and offered as the
# first entry under "Example Inputs". It must be valid, properly indented
# Python, since it is fed to the model as source code.
EXAMPLE_CODE = '''\
def fibonacci(n: int) -> list[int]:
    """Generate the first n Fibonacci numbers."""
    if n <= 0:
        return []
    sequence = [0, 1]
    while len(sequence) < n:
        sequence.append(sequence[-1] + sequence[-2])
    return sequence[:n]
'''

# Global model state
# Each slot starts as None and is filled on first successful initialisation,
# then reused by later calls instead of reloading.
_model = None
_tokenizer = None
_retriever = None
33
+
34
+
35
def load_model(
    adapter_dir: str = "./models/codewraith-lora-3b",
    model_key: str = "3b",
) -> tuple[Any, Any]:
    """Load the fine-tuned model and LoRA adapter.

    The loaded pair is cached in the module globals ``_model`` and
    ``_tokenizer``. Once a model has been loaded, every later call returns
    the cached pair and ignores ``adapter_dir`` and ``model_key``.

    Args:
        adapter_dir: Path to the LoRA adapter directory.
        model_key: Base model key ("3b" or "8b").

    Returns:
        Tuple of (model, tokenizer).
    """
    global _model, _tokenizer  # noqa: PLW0603

    if _model is not None:
        return _model, _tokenizer

    # Imported here rather than at module top, so importing this module does
    # not pull in these packages. Unsloth asks to be imported before
    # peft/transformers so that its patches are applied; keep this order
    # even though import sorters would put peft first.
    from unsloth import FastLanguageModel  # isort: skip
    from peft import PeftModel  # isort: skip

    from codewraith.student.trainer import load_base_model

    model, tokenizer = load_base_model(model_key)
    # Wrap the base model with the trained adapter weights.
    model = PeftModel.from_pretrained(model, adapter_dir)
    # Switch the model into Unsloth's inference mode.
    FastLanguageModel.for_inference(model)

    _model, _tokenizer = model, tokenizer
    return model, tokenizer
64
+
65
+
66
def init_retriever() -> Any:
    """Initialize the RAG retriever if the index exists.

    The retriever is cached in the module global ``_retriever`` only when
    the ``data/chromadb`` directory exists and its collection holds at least
    one entry. In every other case nothing is cached, so the next call
    tries again.

    Returns:
        The cached or newly created retriever, or ``None`` when the
        retriever code cannot be imported, the index directory is missing,
        or the collection is empty.
    """
    global _retriever  # noqa: PLW0603

    if _retriever is not None:
        return _retriever

    try:
        from codewraith.app.retriever import SpecRetriever

        retriever = SpecRetriever()
        if Path("data/chromadb").exists():
            collection = retriever._get_collection()
            # Query the size once and reuse it for the check and the message.
            example_count = collection.count()
            if example_count > 0:
                _retriever = retriever
                print(f"RAG retriever loaded ({example_count} examples)")
                return _retriever
    except ImportError as exc:
        # RAG is optional: keep going without it, but say why it is off
        # instead of failing silently.
        print(f"RAG retriever unavailable, continuing without it: {exc}")

    return None
87
+
88
+
89
def generate_spec(
    source_code: str,
    temperature: float = 0.7,
    top_p: float = 0.9,
    max_tokens: int = 2048,
    use_rag: bool = True,
) -> str:
    """Generate a technical specification from Python source code.

    Uses RAG to retrieve similar code/spec pairs as few-shot context
    when available, improving generation quality.

    Args:
        source_code: Python source code to analyze. ``None``, empty, or
            whitespace-only input returns a hint message without loading
            the model.
        temperature: Sampling temperature (higher = more creative). A value
            of 0 or below switches to greedy decoding, because sampling
            requires a strictly positive temperature.
        top_p: Nucleus sampling threshold (ignored for greedy decoding).
        max_tokens: Maximum tokens to generate.
        use_rag: Whether to use RAG retrieval for context.

    Returns:
        Generated Markdown specification.
    """
    # The UI may hand over None or blank text when the editor is empty.
    if not source_code or not source_code.strip():
        return "*Please paste some Python source code.*"

    model, tokenizer = load_model()

    # Build user content with optional RAG context
    user_content = source_code
    if use_rag:
        retriever = init_retriever()
        if retriever is not None:
            examples = retriever.retrieve(source_code, n_results=3)
            if examples:
                # Retrieved examples are placed before the user's code.
                user_content = retriever.format_context(examples) + source_code

    messages = [
        {"role": "system", "content": SYSTEM_MESSAGE},
        {"role": "user", "content": user_content},
    ]

    inputs = tokenizer.apply_chat_template(
        messages,
        tokenize=True,
        add_generation_prompt=True,
        return_tensors="pt",
    ).to(model.device)

    outputs = model.generate(
        input_ids=inputs,
        # Slider values can arrive as floats; the token budget must be an int.
        max_new_tokens=int(max_tokens),
        **_sampling_kwargs(temperature, top_p),
    )

    # Drop the prompt tokens so only the newly generated text is decoded.
    generated = outputs[0][inputs.shape[-1] :]
    return tokenizer.decode(generated, skip_special_tokens=True)


def _sampling_kwargs(temperature: float, top_p: float) -> dict[str, Any]:
    """Translate the UI sampling controls into ``generate`` keyword arguments."""
    if temperature <= 0:
        # The temperature slider starts at 0.0; sampling with a non-positive
        # temperature is rejected, so fall back to deterministic decoding.
        return {"do_sample": False}
    return {"do_sample": True, "temperature": temperature, "top_p": top_p}
148
+
149
+
150
def create_app() -> Any:
    """Create the Gradio application interface.

    The layout is two columns: the code editor, sampling controls and the
    generate button on the left, the rendered specification on the right.
    Clicking the button calls ``generate_spec`` with the current control
    values.

    Returns:
        A Gradio Blocks app ready to .launch().
    """
    # Imported here so the module can be imported without Gradio installed.
    import gradio as gr

    # Overrides the colours of Mermaid diagram nodes, labels and edges in
    # the rendered Markdown output.
    mermaid_css = """
    .mermaid .node rect,
    .mermaid .node polygon,
    .mermaid .node circle {
        fill: #e8f0fe !important;
        stroke: #4a6fa5 !important;
    }
    .mermaid .nodeLabel,
    .mermaid .edgeLabel,
    .mermaid text {
        color: #1a1a1a !important;
        fill: #1a1a1a !important;
    }
    .mermaid .edgePath .path {
        stroke: #4a6fa5 !important;
    }
    """

    # Second example input: a small class. The indentation inside these
    # literals matters, because the text is offered to users as Python source.
    stack_example = (
        "class Stack:\n    def __init__(self):\n        self._items = []\n\n"
        "    def push(self, item: Any) -> None:\n        self._items.append(item)\n\n"
        "    def pop(self) -> Any:\n        if not self._items:\n"
        '            raise IndexError("pop from empty stack")\n'
        "        return self._items.pop()\n\n"
        "    def peek(self) -> Any:\n        if not self._items:\n"
        '            raise IndexError("peek at empty stack")\n'
        "        return self._items[-1]\n\n"
        "    @property\n    def is_empty(self) -> bool:\n"
        "        return len(self._items) == 0\n"
    )

    with gr.Blocks(
        title="CodeWraith - Module-to-Spec Transformer",
        theme=gr.themes.Soft(),
        css=mermaid_css,
    ) as app:
        gr.Markdown(
            "# CodeWraith\n"
            "Generate technical specifications from Python source code.\n\n"
            "Paste your Python code on the left, adjust sampling parameters, "
            "and click **Generate Specification**."
        )

        with gr.Row():
            with gr.Column(scale=1):
                code_input = gr.Code(
                    language="python",
                    label="Python Source Code",
                    value=EXAMPLE_CODE,
                    lines=20,
                )
                with gr.Row():
                    temperature = gr.Slider(0.0, 2.0, value=0.7, step=0.1, label="Temperature")
                    top_p = gr.Slider(0.0, 1.0, value=0.9, step=0.05, label="Top-p")
                max_tokens = gr.Slider(256, 8192, value=4096, step=256, label="Max Tokens")
                use_rag = gr.Checkbox(value=True, label="Use RAG (retrieve similar examples)")
                generate_btn = gr.Button("Generate Specification", variant="primary")

            with gr.Column(scale=1):
                spec_output = gr.Markdown(label="Generated Specification")

        # Input order must match the positional parameters of generate_spec.
        generate_btn.click(
            fn=generate_spec,
            inputs=[code_input, temperature, top_p, max_tokens, use_rag],
            outputs=spec_output,
        )

        gr.Examples(
            examples=[
                [EXAMPLE_CODE],
                [stack_example],
            ],
            inputs=[code_input],
            label="Example Inputs",
        )

    return app
233
+
234
+
235
def main():
    """Entry point for running the Gradio app.

    Preloads the model from the first adapter directory found on disk (the
    8B adapter is checked before the 3B one), then builds and launches the
    UI. If neither directory exists, a warning is printed and the app is
    launched anyway.
    """
    # Candidate adapter locations, in priority order.
    adapter_candidates = (
        "./models/codewraith-lora-8b",
        "./models/codewraith-lora-3b",
    )
    found = next((path for path in adapter_candidates if Path(path).exists()), None)

    if found is None:
        print("WARNING: No adapter found. Run training first.")
    else:
        print(f"Using adapter: {found}")
        # The base model size is inferred from the adapter path.
        base_key = "8b" if "8b" in found else "3b"
        load_model(adapter_dir=found, model_key=base_key)

    create_app().launch(share=True)
253
+
254
# Start the app only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()