File size: 9,733 Bytes
7618ac2
 
1c8f217
7618ac2
034e6b3
6ada464
7618ac2
6ada464
1c8f217
7618ac2
6ada464
034e6b3
6ada464
 
7618ac2
034e6b3
1c8f217
7618ac2
1c8f217
7618ac2
 
 
1c8f217
7618ac2
 
 
 
 
 
 
 
6ada464
 
7618ac2
 
1c8f217
034e6b3
 
7618ac2
1c8f217
034e6b3
 
 
 
 
 
 
 
 
 
 
 
6ada464
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
034e6b3
 
 
6ada464
 
7618ac2
034e6b3
7618ac2
 
1c8f217
7618ac2
034e6b3
7618ac2
 
1c8f217
7618ac2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1c8f217
7618ac2
 
 
 
1c8f217
034e6b3
 
 
6ada464
 
 
 
 
 
 
 
 
 
 
 
 
 
 
034e6b3
 
 
 
6ada464
 
034e6b3
 
7618ac2
034e6b3
 
 
 
 
 
 
 
6ada464
 
034e6b3
7618ac2
034e6b3
 
ea59ffe
034e6b3
 
 
 
 
ea59ffe
034e6b3
 
 
 
 
6ada464
7618ac2
034e6b3
 
6ada464
 
 
 
 
034e6b3
 
 
 
6ada464
034e6b3
 
 
 
 
 
 
 
 
 
6ada464
 
 
 
 
7618ac2
 
034e6b3
7618ac2
034e6b3
7618ac2
034e6b3
7618ac2
1c8f217
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
import os
import torch
import gradio as gr
import spaces
import json
import random  # <<< NEW
from threading import Thread
from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer

# Model checkpoint to serve; override via the MODEL_ID environment variable.
MODEL_ID = os.getenv("MODEL_ID", "theostos/LLM4Docq-annotator-fp8")
# JSON object mapping dropdown keys -> prefix strings (read by load_prefixes).
RESULT_JSON_PATH = os.getenv("RESULT_JSON_PATH", "result.json")
# JSON list of test-example objects (read by load_test_examples).
TEST_JSON_PATH = os.getenv("TEST_JSON_PATH", "test.json")
# User-turn prompt; build_messages() fills it as
# messages=[{"role": "user", "content": INSTRUCTION_TEMPLATE.format(prefix=..., source=...)}].
# Intended to mirror the training-time prompt, so keep the text unchanged.
INSTRUCTION_TEMPLATE = "You are given a Coq source file along with an optional prefix.\n\n- The **prefix** contains lines that appear *before* the current chunk of code. It provides contextual information to help you understand the surrounding definitions, imports, and notation.\n- The **source** contains the chunk of code you must annotate and complete.\n\nSome parts of the code contain special placeholders:\n\n- [PREDICT_DOCSTRING]: This placeholder appears before an element. You must replace it with a descriptive comment (in Coq comment syntax (* ... *)) that explains what the element does.\n\n- [PREDICT_STATEMENT]: This placeholder appears after an explanatory comment. You must replace it with a valid Coq statement or definition that matches the meaning of the preceding comment.\n\nYour task is to rewrite the entire Coq source chunk, replacing all placeholders with appropriate content, while preserving all other parts of the source code exactly as they are.\n\n### Guidelines\n1. The **prefix** is only provided for context — do **not** modify it or include it in your output.\n2. Rewrite only the **source** content.\n3. Keep all existing Coq syntax, imports, and formatting intact.\n4. Replace [PREDICT_DOCSTRING] with a natural-language description of the next element.\n5. Replace [PREDICT_STATEMENT] with a complete and syntactically correct Coq statement (definition, lemma, theorem, etc.) that corresponds to the immediately preceding comment.\n6. Ensure the generated statements are consistent with the style and logic suggested by the prefix and surrounding code.\n7. Do not add or remove any lines except to substitute the placeholders.\n\n### Output format\nReturn **only** the full rewritten Coq source chunk (without the prefix), with all placeholders replaced.\n\nHere is the context and source:\n\n## Prefix:\n{prefix}\n\n## Source:\n{source}"

# Hugging Face access token from the environment (None when unset).
HF_TOKEN = os.environ.get("HF_TOKEN")

tokenizer = AutoTokenizer.from_pretrained(
    MODEL_ID,
    token=HF_TOKEN,
    use_fast=True,
)
# generate() passes pad_token_id explicitly, so fall back to EOS when the
# tokenizer does not define a dedicated padding token.
if tokenizer.pad_token_id is None:
    if tokenizer.eos_token_id is not None:
        tokenizer.pad_token_id = tokenizer.eos_token_id

# Cached model instance; populated on the first load_model() call.
_model = None


def load_model():
    """Return the causal LM for MODEL_ID, loading it on first use.

    The loaded model is stored in the module-level ``_model`` so that every
    later call reuses it instead of reloading the weights.
    """
    global _model
    if _model is not None:
        return _model
    _model = AutoModelForCausalLM.from_pretrained(
        MODEL_ID,
        token=HF_TOKEN,
        device_map="auto",
        dtype="auto",
        trust_remote_code=True,
    )
    return _model

def build_messages(prefix: str, source: str):
    """Build the single-turn chat transcript sent to the model.

    Args:
        prefix: Context lines that precede the chunk being annotated.
        source: The chunk containing [PREDICT_*] placeholders.

    Returns:
        ``[{"role": "user", "content": <formatted INSTRUCTION_TEMPLATE>}]``,
        ready for ``tokenizer.apply_chat_template``.
    """
    prompt = INSTRUCTION_TEMPLATE.format(prefix=prefix, source=source)
    user_turn = {"role": "user", "content": prompt}
    return [user_turn]

def load_prefixes(path=RESULT_JSON_PATH):
    """Load the dropdown prefix examples from a JSON file.

    The file must hold a JSON object mapping example keys to prefix strings.
    Keys and values are coerced to ``str``. JSON ``null`` values become ``""``
    instead of the literal text ``"None"``.

    Loading is best-effort: any failure (missing/unreadable file, invalid
    JSON, non-object top level) is reported on stdout and ``{}`` is returned,
    so the app can still start.

    Args:
        path: JSON file to read (defaults to RESULT_JSON_PATH).

    Returns:
        dict[str, str] mapping keys to prefix text.
    """
    try:
        with open(path, "r", encoding="utf-8") as f:
            data = json.load(f)
        if not isinstance(data, dict):
            # Report the actual problem; the path is already in the log prefix.
            raise ValueError(
                "expected a JSON object mapping keys -> prefix strings, "
                f"got {type(data).__name__}"
            )
    except Exception as e:  # best-effort startup boundary: log and fall back
        print(f"[warn] Could not load {path}: {e}")
        return {}
    return {str(k): "" if v is None else str(v) for k, v in data.items()}

# --- Test set loading ---
def load_test_examples(path=TEST_JSON_PATH):
    """Load evaluation examples used by the "Draw test example" button.

    Expects a JSON list of objects with the keys:
      - 'prefix'                          -> context for the Prefix editor
      - 'partially_annotated_last_target' -> snippet with [PREDICT_*] placeholders
      - 'fully_annotated_last_target'     -> reference (truth) annotation

    Missing keys and JSON ``null`` values become ``""``. Other values are
    coerced with ``str``. Entries that are not objects are skipped, and the
    number skipped is logged.

    Loading is best-effort: if the file cannot be read or parsed, or is not
    a list, a warning is printed and ``[]`` is returned.

    Returns:
        A list of ``{'prefix', 'target', 'truth'}`` dicts of strings.
    """
    def _text(ex, key):
        # Avoid turning JSON null into the literal string "None", which the
        # UI would otherwise display as a (non-empty) value.
        value = ex.get(key)
        return "" if value is None else str(value)

    try:
        with open(path, "r", encoding="utf-8") as f:
            data = json.load(f)
        if not isinstance(data, list):
            raise ValueError("Test set JSON must be a list of objects.")
    except Exception as e:  # best-effort startup boundary: log and fall back
        print(f"[warn] Could not load test set {path}: {e}")
        return []

    cleaned = [
        {
            "prefix": _text(ex, "prefix"),
            "target": _text(ex, "partially_annotated_last_target"),
            "truth": _text(ex, "fully_annotated_last_target"),
        }
        for ex in data
        if isinstance(ex, dict)
    ]
    print(f"[info] Loaded {len(cleaned)} test examples from {path}")
    skipped = len(data) - len(cleaned)
    if skipped:
        print(f"[warn] Skipped {skipped} non-object entries in {path}")
    return cleaned

# Data loaded once at import time; both can be hot-reloaded from the UI.
PREFIXES = load_prefixes()
PREFIX_KEYS = sorted(PREFIXES)

TEST_EXAMPLES = load_test_examples()

# Estimate duration for ZeroGPU (default is 60s). Shorter = better queue priority.
def _duration(term, deps, temperature, top_p, max_new_tokens):
    # crude: ~2.5 tok/s + 30s headroom
    return int(min(300, max(60, (int(max_new_tokens) / 2.5) + 30)))

@spaces.GPU(duration=_duration)
def generate(term, deps, temperature, top_p, max_new_tokens):
    """Stream the model's annotation of a Rocq/Coq snippet as Markdown.

    Args:
        term: Prefix (context) text; passed to ``build_messages`` as ``prefix``.
        deps: Target snippet with [PREDICT_*] placeholders; passed as ``source``.
        temperature: Sampling temperature. ``0`` (the slider minimum) selects
            greedy decoding, because sampling with temperature 0 is rejected.
        top_p: Nucleus-sampling threshold (only used when sampling).
        max_new_tokens: Upper bound on the number of generated tokens.

    Yields:
        The text generated so far, wrapped in a ```rocq fence, once per
        streamed chunk.

    Raises:
        Any exception raised by ``model.generate`` in the worker thread,
        re-raised here so the UI reports it instead of waiting forever.
    """
    model = load_model()

    messages = build_messages(term, deps)
    # return_dict=True also returns the attention mask. Without it, generate()
    # has to infer the mask, which is ambiguous when pad_token_id falls back
    # to eos_token_id.
    inputs = tokenizer.apply_chat_template(
        messages,
        tokenize=True,
        add_generation_prompt=True,
        return_tensors="pt",
        return_dict=True,
    ).to(model.device)

    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)

    temperature = float(temperature)
    do_sample = temperature > 0
    gen_kwargs = dict(
        inputs,
        max_new_tokens=int(max_new_tokens),
        do_sample=do_sample,
        streamer=streamer,
        pad_token_id=tokenizer.pad_token_id,
        eos_token_id=tokenizer.eos_token_id,
    )
    if do_sample:
        gen_kwargs.update(temperature=temperature, top_p=float(top_p))

    errors = []

    def _run():
        try:
            model.generate(**gen_kwargs)
        except Exception as exc:
            # Without this, the streamer would block forever waiting for
            # tokens. End it so the loop below exits, then re-raise there.
            errors.append(exc)
            streamer.end()

    thread = Thread(target=_run, daemon=True)
    thread.start()

    out = ""
    for token in streamer:  # stream tokens to UI
        out += token
        yield f"```rocq\n{out}\n```"

    thread.join()
    if errors:
        raise errors[0]

def set_prefix_from_key(key: str) -> str:
    """Return the prefix stored under ``key`` in PREFIXES.

    Returns ``""`` when no key is selected or the key is unknown.
    """
    if not key:
        return ""
    return PREFIXES.get(key, "")

# Pull a random example from the loaded test set into the editors.
def _sample_test_example():
    """Return ``(prefix, target, truth_markdown)`` for a random test example.

    ``truth_markdown`` is the reference annotation wrapped in a ```rocq fence,
    or ``""`` when the example has no reference. If no test set is loaded,
    both editors are cleared and the truth panel shows a hint instead.
    """
    if TEST_EXAMPLES:
        picked = random.choice(TEST_EXAMPLES)
        reference = picked["truth"]
        fenced = f"```rocq\n{reference}\n```" if reference else ""
        return picked["prefix"], picked["target"], fenced
    # Empty prefix/target plus a notice in the truth box.
    return "", "", "No test examples loaded. Set TEST_JSON_PATH or add test.json at repo root."

# Re-read TEST_JSON_PATH without restarting the Space.
def _reload_test_set():
    """Reload ``TEST_EXAMPLES`` from disk and return a status-message update."""
    global TEST_EXAMPLES
    TEST_EXAMPLES = load_test_examples()
    count = len(TEST_EXAMPLES)
    message = f"Reloaded {count} test examples from {TEST_JSON_PATH}."
    return gr.update(value=message)

with gr.Blocks(title="Rocq Annotator (ZeroGPU, FP8)") as demo:
    gr.Markdown(
        "# Rocq annotator\n"
        "Pick a **prefix** example from the dropdown to auto-fill the Prefix editor, "
        "then write a **target snippet** (with [PREDICT_STATEMENT]/[PREDICT_DOCSTRING] tags) and click **Annotate**.\n\n"
        "You can also use **🎲 Draw test example** to pull a sample from the test set; the **Baseline (truth)** panel shows the expected annotated result."
    )

    # Toolbar: prefix picker plus reload / sampling actions.
    with gr.Row():
        dropdown = gr.Dropdown(
            choices=PREFIX_KEYS,
            label="Choose a prefix example (from result.json)",
            allow_custom_value=False,
            value=None,
        )

        reload_btn = gr.Button("Reload result.json", variant="secondary")
        sample_btn = gr.Button("🎲 Draw test example", variant="secondary")
        reload_test_btn = gr.Button("Reload test set", variant="secondary")

    # Editable model inputs.
    with gr.Row():
        prefix_box = gr.Code(
            label="Prefix (context; auto-filled from dropdown, then editable)",
            language=None,
            interactive=True,
            lines=18,
        )
        target_box = gr.Code(
            label="Target snippet (contains [PREDICT_STATEMENT] / [PREDICT_DOCSTRING])",
            language=None,
            interactive=True,
            lines=18,
        )

    # Decoding controls forwarded to generate().
    with gr.Row():
        temperature = gr.Slider(0.0, 1.5, value=0.7, step=0.05, label="Temperature")
        top_p = gr.Slider(0.1, 1.0, value=0.9, step=0.05, label="top_p")
        max_new = gr.Slider(32, 512, value=128, step=32, label="max_new_tokens")

    # Output panels: model vs baseline/truth
    with gr.Row():
        out = gr.Markdown(label="Annotated Rocq")
        truth_md = gr.Markdown(label="Baseline (truth)")

    btn = gr.Button("Annotate", variant="primary")

    # --- wiring ---
    dropdown.change(set_prefix_from_key, inputs=dropdown, outputs=prefix_box)

    # Hot-reload result.json without restarting the Space.
    def _reload():
        """Reload PREFIXES/PREFIX_KEYS and refresh the dropdown choices.

        The notice reports how many prefixes were actually loaded and from
        which path. A failed load (which yields an empty mapping) therefore
        shows as "0 prefixes" instead of an unconditional success message.
        """
        global PREFIXES, PREFIX_KEYS
        PREFIXES = load_prefixes()
        PREFIX_KEYS = sorted(PREFIXES.keys())
        return (
            gr.update(choices=PREFIX_KEYS),
            gr.update(value=f"Reloaded {len(PREFIX_KEYS)} prefixes from {RESULT_JSON_PATH}."),
        )

    notice = gr.Markdown("")
    reload_btn.click(_reload, inputs=None, outputs=[dropdown, notice])

    # Test-set actions: draw a random example, or re-read the test file.
    test_notice = gr.Markdown("")
    sample_btn.click(_sample_test_example, inputs=None, outputs=[prefix_box, target_box, truth_md])
    reload_test_btn.click(_reload_test_set, inputs=None, outputs=test_notice)

    # One generation at a time (GPU-bound); results stream into `out`.
    btn.click(
        generate,
        inputs=[prefix_box, target_box, temperature, top_p, max_new],
        outputs=out,
        concurrency_limit=1,
    )

    demo.queue(max_size=20, default_concurrency_limit=1)

if __name__ == "__main__":
    # Start the Gradio server when this file is executed directly.
    demo.launch()