"""Gradio demo: predict a masked token in source code with CodeBERT (MLM)."""

import gradio as gr
import torch
from transformers import RobertaForMaskedLM, RobertaTokenizer

# Placeholder the user types to mark the position to predict.
MASK_PLACEHOLDER = "<mask>"

# Load CodeBERT model and tokenizer once at import time so every request
# reuses the same weights.
model_name = "microsoft/codebert-base-mlm"
tokenizer = RobertaTokenizer.from_pretrained(model_name)
model = RobertaForMaskedLM.from_pretrained(model_name)

# Flush-left so the Markdown renders the same regardless of how the
# component treats leading indentation.
DESCRIPTION_MD = """\
# CodeBERT Masked Language Model

This model predicts masked tokens in code. Use `<mask>` to indicate where you want predictions.

### Examples:
- `def <mask>(x, y): return x + y`
- `import <mask>`
- `for i in <mask>(10):`
- `x = [1, 2, 3]; y = x.<mask>()`
"""


def predict_masked_code(code_with_mask, top_k=5):
    """Predict the masked token in a code snippet.

    Use ``<mask>`` in ``code_with_mask`` to indicate where to predict. When
    several masks are present, only the first one is predicted and filled;
    the remaining masks are left untouched in the displayed code.

    Args:
        code_with_mask: Source code containing at least one ``<mask>``.
        top_k: Number of candidate tokens to return. Coerced to ``int``
            (UI sliders may deliver floats) and clamped to
            ``[1, vocabulary size]``.

    Returns:
        A string with one numbered entry per candidate (token, raw logit
        score, and the code with the first mask filled in), entries
        separated by blank lines; or a string starting with ``"Error:"``
        when no mask is present or inference fails.
    """
    # This function is the UI boundary: any failure is reported to the user
    # as text instead of surfacing as an unhandled exception.
    try:
        # Normalize the user-facing placeholder to the tokenizer's mask token.
        code_with_mask = code_with_mask.replace(
            MASK_PLACEHOLDER, tokenizer.mask_token
        )

        inputs = tokenizer(code_with_mask, return_tensors="pt")

        # Column indices (sequence positions) of every mask token in the
        # single-row batch.
        mask_positions = torch.where(
            inputs["input_ids"] == tokenizer.mask_token_id
        )[1]
        if mask_positions.numel() == 0:
            return (
                "Error: No <mask> token found in the input. "
                "Please include <mask> where you want predictions."
            )

        with torch.no_grad():
            logits = model(**inputs).logits

        # Only the first mask is predicted.
        first_mask_logits = logits[0, mask_positions[0], :]
        k = max(1, min(int(top_k), first_mask_logits.shape[-1]))
        top = torch.topk(first_mask_logits, k)

        results = []
        candidates = zip(top.indices.tolist(), top.values.tolist())
        for rank, (token_id, score) in enumerate(candidates, start=1):
            predicted_token = tokenizer.decode([token_id])
            # Fill just the first mask: the prediction belongs to that
            # position only, so other masks must not receive it.
            filled_code = code_with_mask.replace(
                tokenizer.mask_token, predicted_token, 1
            )
            results.append(
                f"{rank}. {predicted_token} (score: {score:.2f})\n"
                f" Code: {filled_code}"
            )
        return "\n\n".join(results)
    except Exception as e:
        return f"Error: {str(e)}"


# Create Gradio interface
with gr.Blocks(title="CodeBERT Masked Language Model") as demo:
    gr.Markdown(DESCRIPTION_MD)

    with gr.Row():
        with gr.Column():
            code_input = gr.Textbox(
                label="Code with <mask>",
                placeholder="Enter code with <mask> token...",
                lines=5,
                value="def <mask>(x, y):\n    return x + y",
            )
            top_k_slider = gr.Slider(
                minimum=1,
                maximum=10,
                value=5,
                step=1,
                label="Number of predictions",
            )
            predict_btn = gr.Button("Predict", variant="primary")

        with gr.Column():
            output = gr.Textbox(
                label="Predictions",
                lines=15,
                interactive=False,
            )

    # Examples
    gr.Examples(
        examples=[
            ["def <mask>(x, y):\n    return x + y", 5],
            ["import <mask>", 5],
            ["for i in <mask>(10):", 5],
            ["x = [1, 2, 3]\ny = x.<mask>()", 5],
            ["if x <mask> 0:", 5],
            ["class <mask>:", 5],
        ],
        inputs=[code_input, top_k_slider],
    )

    predict_btn.click(
        fn=predict_masked_code,
        inputs=[code_input, top_k_slider],
        outputs=output,
    )


if __name__ == "__main__":
    demo.launch()