theostos's picture
Fix issue with initial_goal not being correctly joined, fix inputs not being expanded correctly into gen_kwarg
0de4dee
import os
import torch
import gradio as gr
import spaces
import json
import random # <<< NEW
from threading import Thread
from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer
# Model checkpoint and test-set file. Both can be overridden through environment
# variables; edit the defaults here to point the Space at another model.
MODEL_ID = os.getenv("MODEL_ID", "theostos/babel-translate-fp8")
TEST_JSON_PATH = os.getenv("TEST_JSON_PATH", "test.json")
# User-turn instruction, keyed by target language. build_messages() fills
# {term} and {dependencies} with str.format, which is why literal braces in the
# text are written doubled ({{ }}).
INSTRUCTION_TEMPLATE = {
    "Rocq": "You are given a proof term:\n\n{term}\n\nYour task is to derive a sequence of Coq tactics that corresponds to this term.\n\nWhen you work through the problem, write down your reasoning in detail inside <think> ... </think> tags. This reasoning should reflect your natural thought process as you explore the structure of the term and figure out what tactics to apply. You should consider different possible approaches, reflect on why some might or might not work, and gradually converge on a tactic choice.\n\nAfter each reasoning block, provide the next (group of) tactic(s) enclosed in:\n\n\\box{{\n <tactic>\n}}\n\nSome dependencies that could be helpful:\n\n{dependencies}",
    "Lean 4": "You are given a proof term:\n\n{term}\n\nYour task is to derive a sequence of Lean 4 tactics that corresponds to this term.\n\nWhen you work through the problem, write down your reasoning in detail inside <think> ... </think> tags. This reasoning should reflect your natural thought process as you explore the structure of the term and figure out what tactics to apply. You should consider different possible approaches, reflect on why some might or might not work, and gradually converge on a tactic choice.\n\nAfter each reasoning block, provide the next (group of) tactic(s) enclosed in:\n\n\\box{{\n <tactic>\n}}\n\nMake sure your proof script respects Lean indentation rules. For example:\n\n\\box{{\n classical\n refine lemma_a\n}}\n\nSome dependencies that could be helpful:\n\n{dependencies}"
}
# Opening of the model's <think> block, keyed by target language. generate()
# appends it right after the rendered chat template, so the model continues
# reasoning from here; {initial_goals} is filled with str.format.
REASONING_TEMPLATE = {
    "Rocq": "<think> Okay, let\'s try to transform this proof term into a sequence of coq tactics. First let\'s write down the hypotheses, and the initial goal (after the \"|-\" symbol) given by the coq proof assistant:\n{initial_goals}.",
    "Lean 4": "<think> Okay, let\'s try to transform this proof term into a sequence of lean 4 tactics. I need to be careful with indentation when I write the steps into the \\box{{}}. First let\'s write down the hypotheses, and the initial goal given by the Lean proof assistant:\n{initial_goals}."
}
# Optional Hugging Face access token (None when the HF_TOKEN env var is unset).
HF_TOKEN = os.getenv("HF_TOKEN")
# The tokenizer is loaded eagerly at import time; the model is loaded lazily
# by load_model().
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, token=HF_TOKEN, use_fast=True)
# generate() passes pad_token_id explicitly, so fall back to the EOS id when
# the tokenizer does not define a padding token.
if tokenizer.pad_token_id is None and tokenizer.eos_token_id is not None:
    tokenizer.pad_token_id = tokenizer.eos_token_id
# Cached model instance; populated on the first call to load_model().
_model = None
def load_model():
    """Return the causal-LM model, loading it on first use.

    The model is read from ``MODEL_ID`` (authenticated with ``HF_TOKEN``) the
    first time this is called and cached in the module-level ``_model``;
    later calls return the cached instance.
    """
    global _model
    if _model is not None:
        return _model
    _model = AutoModelForCausalLM.from_pretrained(
        MODEL_ID,
        token=HF_TOKEN,
        device_map="auto",
        dtype="auto",
        trust_remote_code=True,
    )
    return _model
def build_messages(term: str, deps: str, initial_goals: str, language: str):
    """Build the chat prompt and the reasoning prefix for one translation.

    Args:
        term: Proof term to translate.
        deps: Dependencies, in the target language, to show the model.
        initial_goals: Initial goal(s), inserted into the reasoning prefix.
        language: Target-language key of the prompt templates
            ("Rocq" or "Lean 4").

    Returns:
        ``(messages, reasoning_prefix)``: a one-element chat list holding the
        user instruction, and the text that opens the model's ``<think>`` block.

    Raises:
        KeyError: If ``language`` has no prompt template.
    """
    user_prompt = INSTRUCTION_TEMPLATE[language].format(term=term, dependencies=deps)
    think_prefix = REASONING_TEMPLATE[language].format(initial_goals=initial_goals)
    chat = [{"role": "user", "content": user_prompt}]
    return chat, think_prefix
def load_test_examples(path=TEST_JSON_PATH):
    """Load the test set used by the "Draw test example" button.

    The file must contain a top-level JSON list. Each entry is expected to be
    a dict with a ``"lean"`` and a ``"rocq"`` sub-dict, each holding the keys
    ``"term"``, ``"dependencies"`` and ``"initial_goal"``.

    Args:
        path: Path of the JSON file (defaults to ``TEST_JSON_PATH``).

    Returns:
        The parsed list of examples. If the file is missing, unreadable, not
        valid UTF-8/JSON, or not a list, a ``[warn]`` line is printed and an
        empty list is returned, so the app still starts without a test set.
    """
    try:
        with open(path, "r", encoding="utf-8") as f:
            data = json.load(f)
        if not isinstance(data, list):
            raise ValueError("Test set JSON must be a list of objects.")
    # OSError: missing/unreadable file. ValueError: bad encoding, malformed
    # JSON (json.JSONDecodeError) or the non-list check above. Anything else
    # is a real bug and is allowed to propagate.
    except (OSError, ValueError) as e:
        print(f"[warn] Could not load test set {path}: {e}")
        return []
    print(f"[info] Loaded {len(data)} test examples from {path}")
    return data
# Test set loaded once at import time; _reload_test_set() refreshes it at runtime.
TEST_EXAMPLES = load_test_examples(TEST_JSON_PATH)
def _duration(term, deps, initial_goals, language, temperature, top_p, max_new_tokens):
    """Estimate the GPU lease, in seconds, for one ``generate`` call.

    Used as the ``duration`` callback of ``spaces.GPU`` on ``generate``,
    hence the identical signature; only ``max_new_tokens`` is used. The
    estimate is 30 s plus ``max_new_tokens / 2.5`` s, clamped to [60, 300].
    """
    budget = int(max_new_tokens)
    estimate = 30 + budget / 2.5
    return int(max(60, min(estimate, 300)))
@spaces.GPU(duration=_duration)
def generate(term, deps, initial_goals, language, temperature, top_p, max_new_tokens):
    """Stream a translated proof for ``term`` as fenced Markdown.

    The prompt is the chat-formatted instruction from ``build_messages``
    followed by the opening of the model's ``<think>`` block, so the model
    continues its reasoning from the listed initial goals.

    Args:
        term: Pretty-printed proof term in the source language.
        deps: Dependencies (target language) appearing in the term.
        initial_goals: Initial goal(s) of the proof, as plain text.
        language: Target language key ("Lean 4" or "Rocq").
        temperature: Sampling temperature; values <= 0 select greedy decoding.
        top_p: Nucleus-sampling threshold (unused for greedy decoding).
        max_new_tokens: Generation budget; also sizes the GPU lease through
            ``_duration``.

    Yields:
        The text generated so far, wrapped in a Markdown code fence tagged
        for the target language.

    Raises:
        Any exception raised by ``model.generate`` in the worker thread is
        re-raised here after the stream has been closed.
    """
    model = load_model()
    device = "cuda" if torch.cuda.is_available() else "cpu"
    messages, think_initial = build_messages(term, deps, initial_goals, language)
    prompt_text = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True,
    )
    # The rendered chat template already contains the model's special tokens
    # (e.g. BOS); letting the tokenizer add them again would duplicate them.
    inputs = tokenizer(
        prompt_text + think_initial,
        return_tensors="pt",
        add_special_tokens=False,
    ).to(device)
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)

    temperature = float(temperature)
    # transformers rejects temperature == 0 when sampling, but the UI slider
    # allows 0: treat non-positive temperatures as greedy decoding.
    do_sample = temperature > 0
    gen_kwargs = dict(
        **inputs,
        max_new_tokens=int(max_new_tokens),
        do_sample=do_sample,
        streamer=streamer,
        pad_token_id=tokenizer.pad_token_id,
        eos_token_id=tokenizer.eos_token_id,
    )
    if do_sample:
        gen_kwargs.update(temperature=temperature, top_p=float(top_p))

    errors = []

    def _worker():
        # Run generation off the request thread. On failure, close the stream
        # so the consumer loop below terminates instead of blocking forever,
        # and keep the exception so it can be re-raised to the caller.
        try:
            model.generate(**gen_kwargs)
        except Exception as exc:
            errors.append(exc)
            streamer.end()

    thread = Thread(target=_worker)
    thread.start()

    fence = "lean" if language == "Lean 4" else "rocq"
    out = ""
    for token in streamer:  # stream tokens to the UI as they arrive
        out += token
        yield f"```{fence}\n{out}\n```"
    thread.join()
    if errors:
        raise errors[0]
def _sample_test_example(target, examples=None):
    """Pick a random test example and return the UI fields for it.

    The proof term comes from the *source* side of the example (the language
    opposite to ``target``). The dependencies and initial goals come from the
    *target* side.

    Args:
        target: Selected target language. ``"Rocq"`` means Lean -> Rocq;
            anything else (including ``None``) means Rocq -> Lean.
        examples: Pool to draw from; defaults to the module-level
            ``TEST_EXAMPLES``.

    Returns:
        ``(term, dependencies, initial_goals)``. When the pool is empty, the
        first two are empty strings and the third carries a hint message.
    """
    pool = TEST_EXAMPLES if examples is None else examples
    if not pool:
        return "", "", "No test examples loaded. Set TEST_JSON_PATH or add test.json at repo root."
    ex = random.choice(pool)
    if target == "Rocq":
        source_key, target_key = "lean", "rocq"
    else:
        source_key, target_key = "rocq", "lean"
    goals = ex[target_key]["initial_goal"]
    # Accept either a list of goal strings or a single pre-joined string:
    # "\n".join on a plain string would put a newline between every character.
    if not isinstance(goals, str):
        goals = "\n".join(goals)
    return ex[source_key]["term"], ex[target_key]["dependencies"], goals
def _reload_test_set():
    """Hot-reload the test set from disk and report how many examples loaded.

    Rebinds the module-level ``TEST_EXAMPLES`` so later draws use the fresh
    data, and returns a Gradio update carrying the status message.
    """
    global TEST_EXAMPLES
    TEST_EXAMPLES = load_test_examples()
    count = len(TEST_EXAMPLES)
    message = f"Reloaded {count} test examples from {TEST_JSON_PATH}."
    return gr.update(value=message)
with gr.Blocks(title="Proof translator (ZeroGPU, FP8)") as demo:
    gr.Markdown(
        "# Lean ↔ Rocq proof translator\n"
        "Write a proof term (either Lean or Rocq), "
        "then write dependencies in the target language appearing in the source proof term.\n\n"
        "You can also use **🎲 Draw test example** to pull a sample from the test set."
    )

    # Test-set controls.
    with gr.Row():
        sample_btn = gr.Button("🎲 Draw test example", variant="secondary")
        reload_test_btn = gr.Button("Reload test set", variant="secondary")

    # Inputs: source proof term, target-language dependencies, initial goal(s).
    with gr.Row():
        term_box = gr.Code(
            label="Pretty-printed proof term (source language)",
            language=None,
            interactive=True,
            lines=18,
        )
        dep_box = gr.Code(
            label="Dependencies contained in the proof term (target language)",
            language=None,
            interactive=True,
            lines=18,
        )
        initial_goals_box = gr.Code(
            label="Initial goal",
            language=None,
            interactive=True,
            lines=18,
        )
    # Explicit default: some Gradio versions leave a Dropdown empty when no value
    # is given, and generate() has no prompt template for a None language.
    language_box = gr.Dropdown(
        ["Lean 4", "Rocq"], value="Lean 4", label="Target language"
    )

    # Sampling controls.
    with gr.Row():
        temperature = gr.Slider(0.0, 1.5, value=0.7, step=0.05, label="Temperature")
        top_p = gr.Slider(0.1, 1.0, value=0.9, step=0.05, label="top_p")
        max_new = gr.Slider(256, 8192, value=4096, step=128, label="max_new_tokens")

    # Output panel: the streamed model translation.
    with gr.Row():
        out = gr.Markdown(label="Generated proof")
    btn = gr.Button("Translate", variant="primary")
    test_notice = gr.Markdown("")

    # Event wiring.
    sample_btn.click(_sample_test_example, inputs=language_box, outputs=[term_box, dep_box, initial_goals_box])
    reload_test_btn.click(_reload_test_set, inputs=None, outputs=test_notice)
    btn.click(
        generate,
        inputs=[term_box, dep_box, initial_goals_box, language_box, temperature, top_p, max_new],
        outputs=out,
        concurrency_limit=1,
    )
# Enable request queuing: up to 20 pending events, one running at a time
# unless an event sets its own concurrency limit.
demo.queue(default_concurrency_limit=1, max_size=20)

if __name__ == "__main__":
    demo.launch()