---
license: apache-2.0
language:
- en
base_model:
- deepseek-ai/DeepSeek-Prover-V1.5-SFT
---
# Lean Conjecturer

This model generates new Lean 4 conjectures from a given theorem statement.

![Lean Conjecturer overview](photo.png)

For more details, see https://github.com/Slim205/RL-Lean.


## Usage

```python
from vllm import LLM, SamplingParams

model_name = "Slim205/Lean-conjecturer"

# Default Mathlib preamble; pass it as `context` to prefix the prompt with imports.
LEAN4_DEFAULT_HEADER = "import Mathlib\nimport Aesop\n\nset_option maxHeartbeats 0\n\nopen BigOperators Real Nat Topology Rat\n\n"

START_THM, END_THM = "<theorem>", "</theorem>"
START_CONJ, END_CONJ = "<conjecture>", "</conjecture>"

def get_prompt(theorem, context=''):
    # The prompt ends mid-statement (" theorem") so the model completes the conjecture.
    return f'Complete the following Lean 4 code:\n\n```lean4\n{context.strip()}\n' \
           f'{START_THM}\n{theorem.strip()}\n{END_THM}\n{START_CONJ}\n theorem'

text = "theorem mathd_numbertheory_3 : (∑ x in Finset.range 10, (x + 1) ^ 2) % 10 = 5 := by"
model_inputs = [get_prompt(text)]

print(model_inputs[0])
```

Example output:

```lean4
<theorem>
theorem mathd_numbertheory_3 : (∑ x in Finset.range 10, (x + 1) ^ 2) % 10 = 5 := by
</theorem>
<conjecture>
 theorem mathd_numbertheory_24 : ∑ k in Finset.range 11, k ^ 2 = 385 := by
</conjecture>
```
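To check a generated conjecture with Lean, it can be wrapped with the default header into a standalone Lean 4 file. A minimal sketch, assuming a hypothetical `to_lean_file` helper (not part of the repository) and using `sorry` as a placeholder proof:

```python
# Default Mathlib preamble, as used in the Usage snippet above.
LEAN4_DEFAULT_HEADER = (
    "import Mathlib\nimport Aesop\n\nset_option maxHeartbeats 0\n\n"
    "open BigOperators Real Nat Topology Rat\n\n"
)

def to_lean_file(conjecture: str) -> str:
    # Wrap a conjecture statement into a standalone Lean 4 file; append
    # `sorry` so the statement can be elaborated without a proof.
    stmt = conjecture.strip()
    if stmt.endswith(":= by"):
        stmt += "\n  sorry"
    return LEAN4_DEFAULT_HEADER + stmt + "\n"

conj = "theorem mathd_numbertheory_24 : ∑ k in Finset.range 11, k ^ 2 = 385 := by"
print(to_lean_file(conj))
```

The resulting string can be written to a `.lean` file and fed to the Lean compiler to filter out conjectures that fail to elaborate.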

## Inference

```python
model = LLM(
    model=model_name,
    seed=1,
    trust_remote_code=True,
    swap_space=8,
    tensor_parallel_size=1,
    max_model_len=4096,
)

sampling_params = SamplingParams(
    temperature=1,
    max_tokens=2048,
    top_p=0.95,
    n=1,
)

model_outputs = model.generate(model_inputs, sampling_params, use_tqdm=True)
print(model_outputs[0].outputs[0].text)
```
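Because the prompt already ends with `<conjecture>\n theorem`, the raw completion starts mid-statement and usually finishes with `</conjecture>`. A small post-processing sketch (the `extract_conjecture` helper is an assumption for illustration, not from the repository) recovers a clean statement:

```python
def extract_conjecture(generated: str) -> str:
    # Keep only the text before the closing tag, if the model emitted one.
    body = generated.split("</conjecture>")[0]
    # Re-attach the "theorem" keyword that was already part of the prompt.
    return ("theorem" + body).strip()

sample = " my_conj : 1 + 1 = 2 := by\n</conjecture>"
print(extract_conjecture(sample))  # theorem my_conj : 1 + 1 = 2 := by
```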