File size: 3,574 Bytes
c995694
7ef9108
 
c995694
7ef9108
 
 
 
c995694
7ef9108
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c995694
7ef9108
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c995694
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
import gradio as gr
from transformers import RobertaTokenizer, RobertaForMaskedLM
import torch

# Load CodeBERT model and tokenizer
# NOTE: this downloads the weights from the Hugging Face hub on first run
# and happens at import time, so app startup blocks until the load finishes.
model_name = "microsoft/codebert-base-mlm"
tokenizer = RobertaTokenizer.from_pretrained(model_name)
model = RobertaForMaskedLM.from_pretrained(model_name)

def predict_masked_code(code_with_mask, top_k=5):
    """
    Predict the masked token in code.
    Use <mask> to indicate where to predict.

    Args:
        code_with_mask: Code snippet containing a <mask> placeholder. If
            several masks are present, only the first is predicted.
        top_k: Number of candidate tokens to return. Coerced to int because
            Gradio sliders may deliver floats and torch.topk requires an int.

    Returns:
        A human-readable string listing the top-k candidates and the snippet
        with each candidate filled in, or an "Error: ..." string on failure.
    """
    try:
        # Gradio's Slider can pass 5.0 instead of 5; torch.topk rejects floats.
        top_k = int(top_k)

        # Replace the user-facing <mask> with the tokenizer's mask token
        # (for RoBERTa-based tokenizers this is also "<mask>", so usually a no-op).
        code_with_mask = code_with_mask.replace("<mask>", tokenizer.mask_token)

        # Tokenize input
        inputs = tokenizer(code_with_mask, return_tensors="pt")

        # Positions of mask tokens along the sequence axis of input_ids.
        mask_token_index = torch.where(inputs["input_ids"] == tokenizer.mask_token_id)[1]

        if len(mask_token_index) == 0:
            return "Error: No <mask> token found in the input. Please include <mask> where you want predictions."

        # Only the first mask is predicted; downstream indexing (indices[0])
        # already ignored any additional masks, so make that explicit here.
        mask_token_index = mask_token_index[:1]

        # Inference only — no gradient tracking needed.
        with torch.no_grad():
            outputs = model(**inputs)
            predictions = outputs.logits

        # Logits for the mask position: shape (1, vocab_size).
        mask_token_logits = predictions[0, mask_token_index, :]
        top_tokens = torch.topk(mask_token_logits, top_k, dim=1)

        results = []
        for i, (token_id, score) in enumerate(zip(top_tokens.indices[0].tolist(), top_tokens.values[0].tolist())):
            predicted_token = tokenizer.decode([token_id])
            # count=1 fills only the first mask; the original replaced every
            # mask with the same candidate, which misrepresented the prediction.
            filled_code = code_with_mask.replace(tokenizer.mask_token, predicted_token, 1)
            results.append(f"{i+1}. {predicted_token} (score: {score:.2f})\n   Code: {filled_code}")

        return "\n\n".join(results)

    except Exception as e:
        # UI boundary: surface the failure as text rather than crashing the app.
        return f"Error: {str(e)}"

# Create Gradio interface.
# Layout: a markdown header, then two columns — input + controls on the left,
# read-only prediction output on the right — followed by clickable examples.
with gr.Blocks(title="CodeBERT Masked Language Model") as demo:
    # Usage instructions rendered as markdown at the top of the page.
    gr.Markdown(
        """
        # CodeBERT Masked Language Model
        
        This model predicts masked tokens in code. Use `<mask>` to indicate where you want predictions.
        
        ### Examples:
        - `def <mask>(x, y): return x + y`
        - `import <mask>`
        - `for i in <mask>(10):`
        - `x = [1, 2, 3]; y = x.<mask>()`
        """
    )
    
    with gr.Row():
        # Left column: the code snippet to complete plus the top-k control.
        with gr.Column():
            code_input = gr.Textbox(
                label="Code with <mask>",
                placeholder="Enter code with <mask> token...",
                lines=5,
                value="def <mask>(x, y):\n    return x + y"
            )
            top_k_slider = gr.Slider(
                minimum=1,
                maximum=10,
                value=5,
                step=1,
                label="Number of predictions"
            )
            predict_btn = gr.Button("Predict", variant="primary")
        
        # Right column: model output; interactive=False makes it display-only.
        with gr.Column():
            output = gr.Textbox(
                label="Predictions",
                lines=15,
                interactive=False
            )
    
    # Clickable examples — selecting one populates both inputs above.
    gr.Examples(
        examples=[
            ["def <mask>(x, y):\n    return x + y", 5],
            ["import <mask>", 5],
            ["for i in <mask>(10):", 5],
            ["x = [1, 2, 3]\ny = x.<mask>()", 5],
            ["if x <mask> 0:", 5],
            ["class <mask>:", 5],
        ],
        inputs=[code_input, top_k_slider],
    )
    
    # Wire the button to the prediction function defined above.
    predict_btn.click(
        fn=predict_masked_code,
        inputs=[code_input, top_k_slider],
        outputs=output
    )

# Launch the local web server only when run as a script, not on import.
if __name__ == "__main__":
    demo.launch()