Spaces:
Sleeping
Sleeping
| import gradio as gr | |
| from transformers import RobertaTokenizer, RobertaForMaskedLM | |
| import torch | |
# Load the CodeBERT masked-language-model checkpoint once, at import time.
# The tokenizer is created first and the model (with its MLM head) second,
# both from the same checkpoint id; predict_masked_code reuses these module
# globals for every request. Per the transformers docs, from_pretrained
# returns the model already in evaluation mode.
model_name = "microsoft/codebert-base-mlm"
tokenizer, model = (
    component.from_pretrained(model_name)
    for component in (RobertaTokenizer, RobertaForMaskedLM)
)
def predict_masked_code(code_with_mask, top_k=5):
    """
    Predict the masked token in code.

    Use <mask> to indicate where to predict. Only the first <mask> in the
    input is predicted; any later <mask> placeholders are left in place in
    the filled-in code shown for each candidate.

    Args:
        code_with_mask: Code snippet containing at least one "<mask>".
        top_k: Number of candidates to list. Coerced with int() because a
            Gradio slider may deliver it as a float, which torch.topk rejects.

    Returns:
        A display string with one numbered entry per candidate: the predicted
        token, its raw logit labelled "score", and the snippet with the first
        mask filled in. On any failure (including no mask present) a string
        starting with "Error:" is returned instead of raising, so the UI can
        show it directly.
    """
    try:
        # Replace <mask> with the tokenizer's mask token
        code_with_mask = code_with_mask.replace("<mask>", tokenizer.mask_token)
        inputs = tokenizer(code_with_mask, return_tensors="pt")
        # Positions of every mask token in the (single) input sequence.
        mask_positions = torch.where(inputs["input_ids"][0] == tokenizer.mask_token_id)[0]
        if len(mask_positions) == 0:
            return (
                "Error: No <mask> token found in the input. "
                "Please include <mask> where you want predictions."
            )
        with torch.no_grad():
            logits = model(**inputs).logits
        # Vocabulary logits at the first mask only; scores are raw logits,
        # not probabilities.
        first_mask_logits = logits[0, mask_positions[0], :]
        k = min(int(top_k), first_mask_logits.size(-1))
        top_tokens = torch.topk(first_mask_logits, k)
        results = []
        for rank, (token_id, score) in enumerate(
            zip(top_tokens.indices.tolist(), top_tokens.values.tolist()), start=1
        ):
            raw_token = tokenizer.decode([token_id])
            # RoBERTa's mask token swallows the whitespace before it, and
            # word-start predictions decode with their own leading space, so
            # strip it to avoid "def  add(". Keep pure-whitespace tokens as-is.
            predicted_token = raw_token.strip() or raw_token
            # Fill only the first mask: that is the one these logits describe.
            filled_code = code_with_mask.replace(tokenizer.mask_token, predicted_token, 1)
            results.append(f"{rank}. {predicted_token} (score: {score:.2f})\n Code: {filled_code}")
        return "\n\n".join(results)
    except Exception as e:
        # UI boundary: surface the failure as text instead of crashing the app.
        return f"Error: {str(e)}"
# Gradio UI: a left column with the code box, top-k slider and Predict
# button; a right column with the read-only predictions box; clickable
# examples below; and the button wired to predict_masked_code.
with gr.Blocks(title="CodeBERT Masked Language Model") as demo:
    gr.Markdown(
        "\n"
        "# CodeBERT Masked Language Model\n"
        "This model predicts masked tokens in code. "
        "Use `<mask>` to indicate where you want predictions.\n"
        "### Examples:\n"
        "- `def <mask>(x, y): return x + y`\n"
        "- `import <mask>`\n"
        "- `for i in <mask>(10):`\n"
        "- `x = [1, 2, 3]; y = x.<mask>()`\n"
    )

    with gr.Row():
        with gr.Column():
            code_input = gr.Textbox(
                value="def <mask>(x, y):\n return x + y",
                lines=5,
                placeholder="Enter code with <mask> token...",
                label="Code with <mask>",
            )
            top_k_slider = gr.Slider(
                label="Number of predictions",
                minimum=1,
                maximum=10,
                step=1,
                value=5,
            )
            predict_btn = gr.Button("Predict", variant="primary")
        with gr.Column():
            output = gr.Textbox(label="Predictions", interactive=False, lines=15)

    # Each example row fills (code snippet, top_k) with top_k fixed at 5.
    gr.Examples(
        examples=[
            [snippet, 5]
            for snippet in (
                "def <mask>(x, y):\n return x + y",
                "import <mask>",
                "for i in <mask>(10):",
                "x = [1, 2, 3]\ny = x.<mask>()",
                "if x <mask> 0:",
                "class <mask>:",
            )
        ],
        inputs=[code_input, top_k_slider],
    )

    predict_btn.click(
        outputs=output,
        inputs=[code_input, top_k_slider],
        fn=predict_masked_code,
    )

# Start the web server only when run as a script, not on import.
if __name__ == "__main__":
    demo.launch()