File size: 1,093 Bytes
96b5e6a
88a03c0
96b5e6a
 
 
88a03c0
96b5e6a
 
 
 
 
 
 
 
 
 
88a03c0
 
 
 
 
 
 
 
96b5e6a
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
"""

Inference for docstring generation. Uses T5 (cached after first load).

"""
from typing import Optional

import torch

_cache = {}

def generate_docstring(
    code: str,
    model_name: str = "t5-small",
    max_length: int = 128,
    num_beams: int = 4,
    device: Optional[str] = None,
) -> str:
    """Generate a natural-language summary (docstring) for *code*.

    The tokenizer/model pair is loaded lazily on first use and cached in
    the module-level ``_cache``, keyed by ``model_name``, so repeated
    calls reuse the same weights.

    Args:
        code: Source code to summarize.
        model_name: Hugging Face model identifier (seq2seq, T5-style).
        max_length: Maximum length, in tokens, of the generated summary.
        num_beams: Beam-search width for generation.
        device: Torch device string (e.g. ``"cuda"``/``"cpu"``). Defaults
            to ``"cuda"`` when available, otherwise ``"cpu"``.

    Returns:
        The decoded summary with special tokens stripped.
    """
    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"
    if model_name not in _cache:
        # Deferred import: keeps module import cheap when this helper is unused.
        from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
        _cache[model_name] = {
            "tokenizer": AutoTokenizer.from_pretrained(model_name),
            "model": AutoModelForSeq2SeqLM.from_pretrained(model_name).to(device),
        }
    tokenizer = _cache[model_name]["tokenizer"]
    # Fix: a cached model may sit on the device of an earlier call; move it
    # so the weights match the device the inputs are about to be placed on.
    model = _cache[model_name]["model"].to(device)
    _cache[model_name]["model"] = model
    # T5 is trained with task prefixes; "summarize:" selects summarization.
    input_text = "summarize: " + code
    inputs = tokenizer(
        input_text, return_tensors="pt", truncation=True, max_length=512
    ).to(device)
    with torch.no_grad():  # inference only — skip gradient bookkeeping
        out = model.generate(
            **inputs,
            max_length=max_length,
            num_beams=num_beams,
            early_stopping=True,
        )
    return tokenizer.decode(out[0], skip_special_tokens=True)