"""
Inference for docstring generation. Uses T5 (cached after first load).
"""
from typing import Optional

import torch
# Per-model cache: model_name -> {"tokenizer": ..., "model": ...}.
# Populated lazily on the first generate_docstring() call for a model.
_cache: dict = {}
def generate_docstring(
    code: str,
    model_name: str = "t5-small",
    max_length: int = 128,
    num_beams: int = 4,
    device: Optional[str] = None,
) -> str:
    """Generate a natural-language docstring for a code snippet.

    Runs a seq2seq model (default ``t5-small``) with beam search. The
    tokenizer/model pair is loaded once per ``model_name`` and cached in
    the module-level ``_cache``; the model is moved to the requested
    device on every call so a cached model can serve a different device.

    Args:
        code: Source code to summarize.
        model_name: Hugging Face model identifier.
        max_length: Maximum length, in tokens, of the generated text.
        num_beams: Beam width for beam search.
        device: Torch device string (e.g. ``"cuda"``/``"cpu"``);
            auto-detected from CUDA availability when ``None``.

    Returns:
        The generated docstring with special tokens stripped.
    """
    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"
    if model_name not in _cache:
        # Lazy import: transformers is heavy and only needed on first load.
        from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

        model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
        model.eval()  # inference-only: disable dropout for deterministic output
        _cache[model_name] = {
            "tokenizer": AutoTokenizer.from_pretrained(model_name),
            "model": model,
        }
    tokenizer = _cache[model_name]["tokenizer"]
    # BUG FIX: the model used to be moved to `device` only when first
    # loaded, so a later call requesting a different device sent inputs
    # to one device while the cached model sat on another, raising a
    # device-mismatch error. `.to(device)` here is in-place for modules
    # and a no-op when the model is already on `device`.
    model = _cache[model_name]["model"].to(device)
    # T5 is trained with task prefixes; "summarize:" is the closest match.
    input_text = "summarize: " + code
    inputs = tokenizer(
        input_text, return_tensors="pt", truncation=True, max_length=512
    ).to(device)
    with torch.no_grad():
        out = model.generate(
            **inputs,
            max_length=max_length,
            num_beams=num_beams,
            early_stopping=True,
        )
    return tokenizer.decode(out[0], skip_special_tokens=True)