slenk committed on
Commit
cf6c23e
·
verified ·
1 Parent(s): f93e07a

Upload app.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. app.py +29 -72
app.py CHANGED
@@ -6,19 +6,19 @@ Set HF_REPO_ID environment variable to point to your uploaded adapter.
6
 
7
  from __future__ import annotations
8
 
9
- import json
10
  import os
11
  from pathlib import Path
12
  from typing import Any
13
 
14
  import gradio as gr
15
-
16
  import spaces
17
 
 
 
18
  # --- Config ---
19
 
20
- HF_REPO_ID = os.environ.get("HF_REPO_ID", "slenk/codewraith-lora-3b")
21
- MODEL_KEY = os.environ.get("MODEL_KEY", "3b")
22
  ADAPTER_DIR = "./adapter"
23
 
24
  MODELS = {
@@ -26,18 +26,6 @@ MODELS = {
26
  "8b": "unsloth/Llama-3.1-8B-Instruct",
27
  }
28
 
29
- # Duplicated here since spaces/app.py runs standalone on HF (can't import codewraith)
30
- SYSTEM_MESSAGE = (
31
- "You are CodeWraith, a technical specification generator. "
32
- "Given Python source code, produce a structured Markdown specification "
33
- "that accurately captures all functions, classes, parameters, return types, "
34
- "dependencies, and error handling patterns. "
35
- "Include a mermaid diagram showing the relationships between classes and functions. "
36
- "Use valid mermaid syntax with proper node IDs (no spaces or special characters in IDs). "
37
- "Example: ```mermaid\ngraph TD\n A[ModuleName] --> B[ClassName]\n"
38
- " B --> C[method_name]\n```"
39
- )
40
-
41
  EXAMPLE_CODE = '''\
42
  def fibonacci(n: int) -> list[int]:
43
  """Generate the first n Fibonacci numbers."""
@@ -106,69 +94,39 @@ def load_model() -> tuple[Any, Any]:
106
 
107
 
108
  def init_retriever():
109
- """Initialize retriever if training data is bundled."""
110
  global _retriever # noqa: PLW0603
111
 
112
  if _retriever is not None:
113
  return _retriever
114
 
115
- index_path = Path("chromadb")
116
- data_path = Path("training_pairs_clean.jsonl")
117
-
118
- if not index_path.exists() and data_path.exists():
119
- # Build index from bundled data
120
- try:
121
- import chromadb
122
- from chromadb.utils import embedding_functions
123
-
124
- client = chromadb.PersistentClient(path=str(index_path))
125
- ef = embedding_functions.SentenceTransformerEmbeddingFunction(
126
- model_name="all-MiniLM-L6-v2"
127
- )
128
- collection = client.get_or_create_collection(
129
- name="codewraith_specs", embedding_function=ef
130
- )
131
-
132
- if collection.count() == 0:
133
- pairs = []
134
- with data_path.open() as f:
135
- for line in f:
136
- if line.strip():
137
- pairs.append(json.loads(line))
138
-
139
- for i in range(0, len(pairs), 50):
140
- batch = pairs[i : i + 50]
141
- collection.add(
142
- ids=[f"pair_{i + j}" for j in range(len(batch))],
143
- documents=[p["input"] for p in batch],
144
- metadatas=[{"spec": p["output"]} for p in batch],
145
- )
146
 
147
- _retriever = (client, collection, ef)
148
- except Exception as e:
149
- print(f"RAG init failed: {e}")
 
 
 
 
 
 
150
 
151
- return _retriever
152
 
153
 
154
  def retrieve_context(source_code: str, n_results: int = 3) -> str:
155
  """Retrieve similar examples as context."""
156
- ret = init_retriever()
157
- if ret is None:
158
  return ""
159
 
160
- _, collection, _ = ret
161
- results = collection.query(query_texts=[source_code], n_results=n_results)
 
162
 
163
- parts = ["Here are examples of Python code and their specifications:\n"]
164
- for i, (doc, meta) in enumerate(zip(results["documents"][0], results["metadatas"][0]), 1):
165
- parts.append(
166
- f"\n--- Example {i} ---\n"
167
- f"Code:\n```python\n{doc[:1500]}\n```\n"
168
- f"Specification:\n{meta['spec'][:1500]}\n"
169
- )
170
- parts.append("\nNow generate a specification for the following code:\n")
171
- return "".join(parts)
172
 
173
 
174
  # --- Inference ---
@@ -206,9 +164,8 @@ def generate_spec(
206
  inputs = tokenizer(input_text, return_tensors="pt").to(model.device)
207
  input_len = inputs["input_ids"].shape[-1]
208
 
209
- # Check if input is too long -- truncate RAG context if needed
210
  if input_len > 6000 and use_rag:
211
- # Retry without RAG context
212
  messages = [
213
  {"role": "system", "content": SYSTEM_MESSAGE},
214
  {"role": "user", "content": source_code},
@@ -247,7 +204,7 @@ def create_app():
247
  gr.Markdown(
248
  "# CodeWraith\n"
249
  "Generate technical specifications from Python source code.\n\n"
250
- "Paste your Python code on the left, adjust sampling parameters, "
251
  "and click **Generate Specification**."
252
  )
253
 
@@ -267,10 +224,10 @@ def create_app():
267
  clear_input_btn = gr.Button("Clear Input", variant="secondary")
268
  clear_output_btn = gr.Button("Clear Output", variant="secondary")
269
 
270
- spec_output = gr.Markdown(label="Generated Specification")
271
-
272
  gr.Markdown("*Model loads on first generation (~30s). Subsequent calls are fast.*")
273
 
 
 
274
  loading_msg = "*Generating specification... (loading model if first run)*"
275
  generate_btn.click(
276
  fn=lambda: gr.update(value=loading_msg),
@@ -294,8 +251,8 @@ def create_app():
294
  return app
295
 
296
 
297
- # Preload model on startup (before GPU decorator kicks in)
298
- print("Preloading model and adapter...")
299
  download_adapter()
300
  print("Adapter ready. Model will load on first GPU request.")
301
 
 
6
 
7
  from __future__ import annotations
8
 
 
9
  import os
10
  from pathlib import Path
11
  from typing import Any
12
 
13
  import gradio as gr
 
14
  import spaces
15
 
16
+ from codewraith import SYSTEM_MESSAGE
17
+
18
  # --- Config ---
19
 
20
+ HF_REPO_ID = os.environ.get("HF_REPO_ID", "slenk/codewraith-lora-8b")
21
+ MODEL_KEY = os.environ.get("MODEL_KEY", "8b")
22
  ADAPTER_DIR = "./adapter"
23
 
24
  MODELS = {
 
26
  "8b": "unsloth/Llama-3.1-8B-Instruct",
27
  }
28
 
 
 
 
 
 
 
 
 
 
 
 
 
29
  EXAMPLE_CODE = '''\
30
  def fibonacci(n: int) -> list[int]:
31
  """Generate the first n Fibonacci numbers."""
 
94
 
95
 
96
  def init_retriever():
97
+ """Initialize retriever if ChromaDB index exists."""
98
  global _retriever # noqa: PLW0603
99
 
100
  if _retriever is not None:
101
  return _retriever
102
 
103
+ try:
104
+ from codewraith.app.retriever import SpecRetriever
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
105
 
106
+ retriever = SpecRetriever()
107
+ if Path("data/chromadb").exists():
108
+ collection = retriever._get_collection()
109
+ if collection.count() > 0:
110
+ _retriever = retriever
111
+ print(f"RAG retriever loaded ({collection.count()} examples)")
112
+ return _retriever
113
+ except ImportError:
114
+ pass
115
 
116
+ return None
117
 
118
 
119
  def retrieve_context(source_code: str, n_results: int = 3) -> str:
120
  """Retrieve similar examples as context."""
121
+ retriever = init_retriever()
122
+ if retriever is None:
123
  return ""
124
 
125
+ examples = retriever.retrieve(source_code, n_results=n_results)
126
+ if not examples:
127
+ return ""
128
 
129
+ return retriever.format_context(examples)
 
 
 
 
 
 
 
 
130
 
131
 
132
  # --- Inference ---
 
164
  inputs = tokenizer(input_text, return_tensors="pt").to(model.device)
165
  input_len = inputs["input_ids"].shape[-1]
166
 
167
+ # Retry without RAG if input too long
168
  if input_len > 6000 and use_rag:
 
169
  messages = [
170
  {"role": "system", "content": SYSTEM_MESSAGE},
171
  {"role": "user", "content": source_code},
 
204
  gr.Markdown(
205
  "# CodeWraith\n"
206
  "Generate technical specifications from Python source code.\n\n"
207
+ "Paste your Python code below, adjust sampling parameters, "
208
  "and click **Generate Specification**."
209
  )
210
 
 
224
  clear_input_btn = gr.Button("Clear Input", variant="secondary")
225
  clear_output_btn = gr.Button("Clear Output", variant="secondary")
226
 
 
 
227
  gr.Markdown("*Model loads on first generation (~30s). Subsequent calls are fast.*")
228
 
229
+ spec_output = gr.Markdown(label="Generated Specification")
230
+
231
  loading_msg = "*Generating specification... (loading model if first run)*"
232
  generate_btn.click(
233
  fn=lambda: gr.update(value=loading_msg),
 
251
  return app
252
 
253
 
254
+ # Preload adapter on startup (CPU time, free)
255
+ print("Preloading adapter...")
256
  download_adapter()
257
  print("Adapter ready. Model will load on first GPU request.")
258