theostos committed
Commit 034e6b3 · 1 Parent(s): 7618ac2

Add template of prefix, update prompt

Files changed (2):
  1. app.py +76 -19
  2. result.json +0 -0
app.py CHANGED
@@ -2,18 +2,15 @@ import os
 import torch
 import gradio as gr
 import spaces
+import json
 from threading import Thread
 from transformers import AutoTokenizer, AutoModelForCausalLM, FineGrainedFP8Config, TextIteratorStreamer
 
 # >>>> CHANGE THIS <<<<
 MODEL_ID = os.getenv("MODEL_ID", "theostos/LLM4Docq-annotator")
-
+RESULT_JSON_PATH = os.getenv("RESULT_JSON_PATH", "result.json")
 # Matches your training style: messages=[{"role":"user","content": template.format(term=..., dependencies=...)}]
-INSTRUCTION_TEMPLATE = (
-    "You are a Rocq code annotator. Given the Coq term and its dependencies, "
-    "produce helpful inline comments and explanations.\n\n"
-    "Term:\n{term}\n\nDependencies:\n{dependencies}\n"
-)
+INSTRUCTION_TEMPLATE = "You are given a Coq source file along with an optional prefix.\n\n- The **prefix** contains lines that appear *before* the current chunk of code. It provides contextual information to help you understand the surrounding definitions, imports, and notation.\n- The **source** contains the chunk of code you must annotate and complete.\n\nSome parts of the code contain special placeholders:\n\n- [PREDICT_DOCSTRING]: This placeholder appears before an element. You must replace it with a descriptive comment (in Coq comment syntax (* ... *)) that explains what the element does.\n\n- [PREDICT_STATEMENT]: This placeholder appears after an explanatory comment. You must replace it with a valid Coq statement or definition that matches the meaning of the preceding comment.\n\nYour task is to rewrite the entire Coq source chunk, replacing all placeholders with appropriate content, while preserving all other parts of the source code exactly as they are.\n\n### Guidelines\n1. The **prefix** is only provided for context — do **not** modify it or include it in your output.\n2. Rewrite only the **source** content.\n3. Keep all existing Coq syntax, imports, and formatting intact.\n4. Replace [PREDICT_DOCSTRING] with a natural-language description of the next element.\n5. Replace [PREDICT_STATEMENT] with a complete and syntactically correct Coq statement (definition, lemma, theorem, etc.) that corresponds to the immediately preceding comment.\n6. Ensure the generated statements are consistent with the style and logic suggested by the prefix and surrounding code.\n7. Do not add or remove any lines except to substitute the placeholders.\n\n### Output format\nReturn **only** the full rewritten Coq source chunk (without the prefix), with all placeholders replaced.\n\nHere is the context and source:\n\n## Prefix:\n{prefix}\n\n## Source:\n{source}"
 
 HF_TOKEN = os.getenv("HF_TOKEN")
 
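Note: the new INSTRUCTION_TEMPLATE is a single .format()-style template with two slots, {prefix} and {source}, replacing the old {term}/{dependencies} pair. A minimal sketch of how it gets rendered into a chat message (build_messages, updated in the next hunk, does exactly this; the example prefix and source below are invented for illustration):

example_prefix = "Require Import Arith."
example_source = (
    "[PREDICT_DOCSTRING]\n"
    "Definition double (n : nat) : nat := n + n."
)
content = INSTRUCTION_TEMPLATE.format(prefix=example_prefix, source=example_source)
messages = [{"role": "user", "content": content}]
# messages is then run through the tokenizer's chat template before generation.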
@@ -37,17 +34,32 @@ def load_model():
     )
     return _model
 
-def build_messages(term: str, deps: str):
-    content = INSTRUCTION_TEMPLATE.format(term=term, dependencies=deps)
+def build_messages(prefix: str, source: str):
+    content = INSTRUCTION_TEMPLATE.format(prefix=prefix, source=source)
     return [{"role": "user", "content": content}]
 
+def load_prefixes(path=RESULT_JSON_PATH):
+    try:
+        with open(path, "r", encoding="utf-8") as f:
+            data = json.load(f)
+        if not isinstance(data, dict):
+            raise ValueError("result.json must be a JSON object mapping keys -> prefix strings.")
+        # coerce to str->str
+        return {str(k): str(v) for k, v in data.items()}
+    except Exception as e:
+        print(f"[warn] Could not load {path}: {e}")
+        return {}
+
+PREFIXES = load_prefixes()
+PREFIX_KEYS = sorted(PREFIXES.keys())
+
 # Estimate duration for ZeroGPU (default is 60s). Shorter = better queue priority.
-def _duration(term, deps, temperature, top_p, max_new_tokens, repetition_penalty):
+def _duration(term, deps, temperature, top_p, max_new_tokens):
     # crude: ~2.5 tok/s + 30s headroom
     return int(min(300, max(60, (int(max_new_tokens) / 2.5) + 30)))
 
 @spaces.GPU(duration=_duration)
-def generate(term, deps, temperature, top_p, max_new_tokens, repetition_penalty):
+def generate(term, deps, temperature, top_p, max_new_tokens):
     model = load_model()
     device = "cuda" if torch.cuda.is_available() else "cpu"
 
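Two details worth noting in this hunk: load_prefixes fails soft (it returns {} so the Space still boots if result.json is missing or malformed), and _duration is passed as a callable to @spaces.GPU, which ZeroGPU can invoke with the same arguments as the wrapped function to size the GPU slot per request. A quick arithmetic check of the estimator (pure Python, no GPU needed; the positional arguments mirror generate's new signature):

assert _duration(None, None, 0.2, 0.9, 128) == 81    # (128 / 2.5) + 30 = 81.2, int() -> 81
assert _duration(None, None, 0.2, 0.9, 4096) == 300  # 1668.4, clamped to the 300 s cap
assert _duration(None, None, 0.2, 0.9, 32) == 60     # 42.8, raised to the 60 s floor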
@@ -79,23 +91,68 @@ def generate(term, deps, temperature, top_p, max_new_tokens, repetition_penalty)
         out += token
         yield f"```rocq\n{out}\n```"
 
-with gr.Blocks(title="Rocq Annotator (ZeroGPU)") as demo:
-    gr.Markdown("# Rocq annotator\nThe model will produce annotated Rocq code.")
+def set_prefix_from_key(key: str) -> str:
+    return PREFIXES.get(key, "") if key else ""
+
+with gr.Blocks(title="Rocq Annotator (ZeroGPU, FP8)") as demo:
+    gr.Markdown(
+        "# Rocq annotator\n"
+        "Pick a **prefix** example from the dropdown to auto-fill the Prefix editor, "
+        "then write a **target snippet** (with [PREDICT_STATEMENT]/[PREDICT_DOCSTRING] tags) and click **Annotate**."
+    )
+
     with gr.Row():
-        term = gr.Textbox(label="Prefix", lines=100, placeholder="Paste the prefix to use")
-        deps = gr.Textbox(label="To annotate", lines=8, placeholder="The code to annotate")
+        dropdown = gr.Dropdown(
+            choices=PREFIX_KEYS,
+            label="Choose a prefix example (from result.json)",
+            allow_custom_value=False,
+            value=None,
+        )
+
+        reload_btn = gr.Button("Reload result.json", variant="secondary")
+
     with gr.Row():
-        temperature = gr.Slider(0.0, 1.5, value=0.7, step=0.05, label="Temperature")
+        prefix_box = gr.Code(
+            label="Prefix (context; auto-filled from dropdown, then editable)",
+            language="coq",
+            interactive=True,
+            lines=18,
+        )
+        target_box = gr.Code(
+            label="Target snippet (contains [PREDICT_STATEMENT] / [PREDICT_DOCSTRING])",
+            language="coq",
+            interactive=True,
+            lines=18,
+            placeholder="Paste the code to annotate…",
+        )
+
+    with gr.Row():
+        temperature = gr.Slider(0.0, 1.5, value=0.2, step=0.05, label="Temperature")
         top_p = gr.Slider(0.1, 1.0, value=0.9, step=0.05, label="top_p")
-        max_new = gr.Slider(256, 8192, value=4096, step=32, label="max_new_tokens")
+        max_new = gr.Slider(32, 512, value=128, step=32, label="max_new_tokens")
+
     out = gr.Markdown(label="Annotated Rocq")
-    btn = gr.Button("Annotate")
+    btn = gr.Button("Annotate", variant="primary")
+
+    # --- wiring ---
+    dropdown.change(set_prefix_from_key, inputs=dropdown, outputs=prefix_box)
+    # Optional: hot reload result.json without restarting Space
+    def _reload():
+        global PREFIXES, PREFIX_KEYS
+        PREFIXES = load_prefixes()
+        PREFIX_KEYS = sorted(PREFIXES.keys())
+        # return updated dropdown (choices) and a notice
+        return gr.update(choices=PREFIX_KEYS), gr.update(value="Reloaded result.json.")
+    notice = gr.Markdown("")
+    reload_btn.click(_reload, inputs=None, outputs=[dropdown, notice])
+
     btn.click(
         generate,
-        inputs=[term, deps, temperature, top_p, max_new],
+        inputs=[prefix_box, target_box, temperature, top_p, max_new],
         outputs=out,
-        concurrency_limit=1,  # cooperate with ZeroGPU queues
+        concurrency_limit=1,
     )
+
 demo.queue(max_size=20, default_concurrency_limit=1)
 
 if __name__ == "__main__":
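The hunk shows only the tail of generate: a loop that accumulates streamed tokens and re-yields them wrapped in a rocq code fence. The setup feeding that loop sits outside the diff context; below is a minimal sketch of what it plausibly looks like, assuming the standard transformers TextIteratorStreamer pattern (the tokenizer loading and all names not visible in the diff are assumptions, not the Space's actual code):

# Sketch of the portion of generate() the diff does not show (assumption:
# the standard TextIteratorStreamer pattern; the real body may differ).
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, token=HF_TOKEN)
streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
input_ids = tokenizer.apply_chat_template(
    build_messages(term, deps),  # term = prefix box, deps = target snippet
    add_generation_prompt=True,
    return_tensors="pt",
).to(device)
Thread(target=model.generate, kwargs=dict(
    input_ids=input_ids,
    max_new_tokens=int(max_new_tokens),
    temperature=float(temperature),
    top_p=float(top_p),
    do_sample=float(temperature) > 0,
    streamer=streamer,
)).start()
out = ""
for token in streamer:            # tokens arrive as the background thread generates
    out += token
    yield f"```rocq\n{out}\n```"  # re-render the growing Rocq code block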
 
result.json ADDED
The diff for this file is too large to render. See raw diff
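Although result.json itself is too large to render, load_prefixes in app.py pins down its contract: a flat JSON object mapping example names to Coq prefix strings. A hypothetical two-entry illustration (keys and values invented, written here as the equivalent Python dict):

expected_shape = {
    "list_utils": "Require Import List.\nImport ListNotations.",
    "arith_basics": "Require Import Arith.",
}
# Anything else at the top level (e.g. a JSON array) trips the ValueError in
# load_prefixes, which logs a warning and leaves the dropdown empty.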