"""
Inference for docstring generation. Uses T5 (cached after first load).
"""
from typing import Optional

import torch
_cache = {}
def generate_docstring(
    code: str,
    model_name: str = "t5-small",
    max_length: int = 128,
    num_beams: int = 4,
    device: Optional[str] = None,
) -> str:
    """Generate a natural-language summary (docstring) for *code*.

    Runs a seq2seq model (default: ``t5-small``) with the ``summarize:``
    prefix. The tokenizer and model are loaded once per
    ``(model_name, device)`` pair and cached in the module-level ``_cache``.

    Args:
        code: Source code to summarize.
        model_name: Hugging Face model identifier for an encoder-decoder model.
        max_length: Maximum length (in tokens) of the generated summary.
        num_beams: Beam-search width for generation.
        device: Torch device string (e.g. ``"cuda"``/``"cpu"``). Defaults to
            CUDA when available, otherwise CPU.

    Returns:
        The decoded summary text, with special tokens stripped.
    """
    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"
    # Key the cache on (model_name, device): the model is moved to a specific
    # device at load time, so a model cached for one device must not be reused
    # for inputs placed on another (would raise a device-mismatch error).
    cache_key = (model_name, device)
    if cache_key not in _cache:
        # Imported lazily so that merely importing this module does not pull
        # in transformers (and its load time) until generation is requested.
        from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
        _cache[cache_key] = {
            "tokenizer": AutoTokenizer.from_pretrained(model_name),
            "model": AutoModelForSeq2SeqLM.from_pretrained(model_name).to(device),
        }
    tokenizer = _cache[cache_key]["tokenizer"]
    model = _cache[cache_key]["model"]
    # T5 is a text-to-text model; the task is selected via an input prefix.
    input_text = "summarize: " + code
    # Truncate long inputs to the encoder's supported context (512 tokens here).
    inputs = tokenizer(input_text, return_tensors="pt", truncation=True, max_length=512).to(device)
    with torch.no_grad():
        out = model.generate(**inputs, max_length=max_length, num_beams=num_beams, early_stopping=True)
    return tokenizer.decode(out[0], skip_special_tokens=True)