"""
Inference script for docstring generation from Python code.
Uses Hugging Face Transformers (T5 or CodeT5).
"""
import argparse
from typing import Optional

import torch
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
def generate_docstring(
    code: str,
    model_name: str = "t5-small",
    max_length: int = 128,
    num_beams: int = 4,
    device: Optional[str] = None,
) -> str:
    """Generate a natural-language summary (docstring) for a code snippet.

    The tokenizer and seq2seq model named by ``model_name`` are loaded on
    every call. The snippet is prefixed with ``"summarize: "``, tokenized,
    passed to ``model.generate``, and the first returned sequence is decoded.

    Args:
        code: Source text to summarize. The tokenized input (prefix
            included) is truncated to 512 tokens.
        model_name: Identifier passed to ``from_pretrained`` for both the
            tokenizer and the model.
        max_length: ``max_length`` forwarded to ``model.generate``.
        num_beams: ``num_beams`` forwarded to ``model.generate``.
        device: Torch device string. When ``None``, ``"cuda"`` is chosen if
            ``torch.cuda.is_available()`` and ``"cpu"`` otherwise.

    Returns:
        The decoded output text with special tokens removed.
    """
    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"

    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForSeq2SeqLM.from_pretrained(model_name).to(device)

    # T5 expects a prefix for the task; we use "summarize:" for generic text/code summary
    input_text = "summarize: " + code
    inputs = tokenizer(
        input_text,
        return_tensors="pt",
        truncation=True,
        max_length=512,
    ).to(device)

    with torch.no_grad():
        out = model.generate(
            **inputs,
            max_length=max_length,
            num_beams=num_beams,
            # early_stopping is a beam-search-only option; enable it only when
            # beam search is actually in use so single-beam calls do not pass
            # an inapplicable flag.
            early_stopping=num_beams > 1,
        )
    return tokenizer.decode(out[0], skip_special_tokens=True)
def main():
    """Parse CLI arguments, generate a docstring, and print it to stdout.

    ``--input`` is interpreted as a file path only when it is shorter than
    260 characters and ends with ``.py``; in that case the file's contents
    are used as the code to summarize. If the file cannot be read, a warning
    is written to stderr and the raw ``--input`` string is used as the code
    instead (best-effort fallback).
    """
    # Local import: only needed for the fallback warning below.
    import sys

    parser = argparse.ArgumentParser()
    parser.add_argument("--input", type=str, required=True, help="Python code snippet (or path to file)")
    parser.add_argument("--model_name", type=str, default="t5-small")
    parser.add_argument("--max_length", type=int, default=128)
    parser.add_argument("--num_beams", type=int, default=4)
    args = parser.parse_args()

    code = args.input
    if len(code) < 260 and code.endswith(".py"):
        try:
            # Python source files are UTF-8 by default, so read them as such
            # rather than relying on the platform's locale encoding.
            with open(code, "r", encoding="utf-8") as f:
                code = f.read()
        except (OSError, ValueError) as err:
            # OSError: missing/unreadable path. ValueError: undecodable bytes
            # (UnicodeDecodeError) or an invalid path string. Keep the
            # best-effort behavior, but tell the user about the fallback.
            print(
                f"warning: could not read {args.input!r} as a file ({err}); "
                "treating --input as literal code",
                file=sys.stderr,
            )

    docstring = generate_docstring(
        code,
        model_name=args.model_name,
        max_length=args.max_length,
        num_beams=args.num_beams,
    )
    print(docstring)
# Run the command-line interface only when this file is executed as a script.
if __name__ == "__main__":
    main()
|