Spaces:
Sleeping
Sleeping
| import gradio as gr | |
| from transformers import RobertaTokenizer, RobertaForMaskedLM | |
| import torch | |
# Load the CodeBERT masked-language-model checkpoint and its matching
# tokenizer once, at module level, so every call to the prediction function
# reuses the same objects instead of reloading them.
model_name = "microsoft/codebert-base-mlm"
tokenizer, model = (
    RobertaTokenizer.from_pretrained(model_name),
    RobertaForMaskedLM.from_pretrained(model_name),
)
def predict(code, num_predictions=5):
    """
    Predict the token hidden behind the first ``<mask>`` in a code snippet.

    Args:
        code: Code snippet containing at least one ``<mask>`` placeholder.
        num_predictions: Number of top candidates to return. UI widgets may
            deliver this as a float, so it is coerced to ``int`` and clamped
            to ``[0, vocab_size]`` before being passed to ``torch.topk``.

    Returns:
        A JSON-serialisable dict. On success it contains ``original_code``
        and ``predictions``, a list of dicts with keys ``rank``, ``token``,
        ``score`` (raw logit), ``probability`` (softmax over the vocabulary)
        and ``completed_code``. On failure it contains ``error`` and an
        empty ``predictions`` list.

    Only the first ``<mask>`` is predicted. Any further ``<mask>``
    placeholders are left as-is in ``completed_code``, because the
    prediction belongs to the first position only.
    """
    try:
        # Normalise the user-facing placeholder to the tokenizer's own mask token.
        code_input = code.replace("<mask>", tokenizer.mask_token)
        inputs = tokenizer(code_input, return_tensors="pt")

        # Sequence positions of every mask token in the single-example batch.
        mask_positions = torch.where(inputs["input_ids"] == tokenizer.mask_token_id)[1]
        if len(mask_positions) == 0:
            return {
                "error": "No <mask> token found in the input. Please include <mask> where you want predictions.",
                "predictions": []
            }

        with torch.no_grad():
            logits = model(**inputs).logits

        # Vocabulary scores for the first mask position only.
        mask_logits = logits[0, mask_positions[0], :]
        probabilities = mask_logits.softmax(dim=-1)

        # torch.topk needs an int k in [0, vocab_size]; slider values can be floats.
        k = min(max(int(num_predictions), 0), mask_logits.size(-1))
        top_scores, top_ids = torch.topk(mask_logits, k)

        predictions = []
        for rank, (token_id, score) in enumerate(zip(top_ids.tolist(), top_scores.tolist()), 1):
            predicted_token = tokenizer.decode([token_id])
            predictions.append({
                "rank": rank,
                "token": predicted_token,
                "score": round(float(score), 4),
                "probability": round(probabilities[token_id].item(), 4),
                # Fill only the first mask: that is the position this token was predicted for.
                "completed_code": code_input.replace(tokenizer.mask_token, predicted_token, 1),
            })
        return {
            "original_code": code,
            "predictions": predictions
        }
    except Exception as e:
        # UI boundary: report the failure in the JSON output instead of crashing the app.
        return {
            "error": str(e),
            "predictions": []
        }
# Gradio UI. The left column holds the code textbox, the prediction-count
# slider and the Predict button. The right column shows predict()'s dict
# result as JSON. Clicking an example row only fills in the inputs.
with gr.Blocks(title="CodeBERT Masked Language Model") as demo:
    gr.Markdown(
        "\n"
        "# CodeBERT Masked Language Model\n"
        "This model predicts masked tokens in code. Use `<mask>` to indicate where you want predictions.\n"
        "### Examples:\n"
        "- `def <mask>(x, y): return x + y`\n"
        "- `import <mask>`\n"
        "- `for i in <mask>(10):`\n"
        "- `x = [1, 2, 3]; y = x.<mask>()`\n"
    )
    with gr.Row():
        with gr.Column():
            code_input = gr.Textbox(
                value="def <mask>(x, y):\n return x + y",
                lines=5,
                placeholder="Enter code with <mask> token...",
                label="Code with <mask>",
            )
            num_predictions_slider = gr.Slider(
                label="Number of predictions",
                minimum=1,
                maximum=10,
                step=1,
                value=5,
            )
            predict_btn = gr.Button("Predict", variant="primary")
        with gr.Column():
            output = gr.JSON(label="Predictions")

    # Every example row pairs a snippet with the default prediction count of 5.
    gr.Examples(
        examples=[
            [snippet, 5]
            for snippet in (
                "def <mask>(x, y):\n return x + y",
                "import <mask>",
                "for i in <mask>(10):",
                "x = [1, 2, 3]\ny = x.<mask>()",
                "if x <mask> 0:",
                "class <mask>:",
            )
        ],
        inputs=[code_input, num_predictions_slider],
    )

    predict_btn.click(
        fn=predict,
        inputs=[code_input, num_predictions_slider],
        outputs=output,
    )

if __name__ == "__main__":
    demo.launch()