syeedalireza's picture
Upload folder using huggingface_hub
88a03c0 verified
"""
Inference for docstring generation. Uses T5 (cached after first load).
"""
from typing import Optional

import torch
# Per-process cache of loaded checkpoints, keyed by model name. Each entry holds
# the "tokenizer" and "model" objects so repeated calls skip from_pretrained().
_cache = {}


def generate_docstring(
    code: str,
    model_name: str = "t5-small",
    max_length: int = 128,
    num_beams: int = 4,
    device: Optional[str] = None,
    max_input_length: int = 512,
) -> str:
    """Generate a natural-language summary (docstring) for a code snippet.

    The snippet is prefixed with ``"summarize: "`` and decoded with beam
    search. The tokenizer and model for ``model_name`` are loaded on first
    use and cached in ``_cache`` for later calls.

    Args:
        code: Source code to summarize.
        model_name: Hugging Face checkpoint name or local path of a
            sequence-to-sequence model.
        max_length: Maximum length, in tokens, of the generated output.
        num_beams: Number of beams used by beam search.
        device: Torch device string (e.g. ``"cpu"``, ``"cuda"``). Defaults to
            ``"cuda"`` when available, otherwise ``"cpu"``.
        max_input_length: The tokenized input is truncated to this many tokens.

    Returns:
        The generated text, decoded with special tokens removed.
    """
    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"

    entry = _cache.get(model_name)
    if entry is None:
        # Imported lazily so importing this module does not require transformers.
        from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

        entry = {
            "tokenizer": AutoTokenizer.from_pretrained(model_name),
            "model": AutoModelForSeq2SeqLM.from_pretrained(model_name),
        }
        _cache[model_name] = entry

    tokenizer = entry["tokenizer"]
    # The cache is keyed only by model name, so a model cached by an earlier
    # call may sit on a different device than the one requested now. Move it
    # to `device` on every call so it always matches the inputs below.
    model = entry["model"].to(device)

    inputs = tokenizer(
        "summarize: " + code,
        return_tensors="pt",
        truncation=True,
        max_length=max_input_length,
    ).to(device)
    with torch.no_grad():
        out = model.generate(
            **inputs,
            max_length=max_length,
            num_beams=num_beams,
            early_stopping=True,
        )
    return tokenizer.decode(out[0], skip_special_tokens=True)