# Uploaded by syeedalireza via huggingface_hub (commit 5b30d83, verified).
"""
Inference script for docstring generation from Python code.
Uses Hugging Face Transformers (T5 or CodeT5).
"""
import argparse
from pathlib import Path
from typing import Optional

import torch
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
def generate_docstring(
    code: str,
    model_name: str = "t5-small",
    max_length: int = 128,
    num_beams: int = 4,
    device: Optional[str] = None,
) -> str:
    """Generate a natural-language summary/docstring for a Python code snippet.

    Loads a seq2seq model (T5 or CodeT5) from the Hugging Face Hub and runs
    beam-search generation on the snippet prefixed with the T5 "summarize:"
    task tag.

    Args:
        code: The Python source snippet to summarize.
        model_name: Hub identifier of the seq2seq model to load.
        max_length: Maximum length (in tokens) of the generated docstring.
        num_beams: Beam width for beam-search decoding.
        device: Torch device string ("cuda"/"cpu"); auto-detected when None.

    Returns:
        The generated docstring, with special tokens stripped.
    """
    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForSeq2SeqLM.from_pretrained(model_name).to(device)
    # Disable dropout etc. — we are doing inference, not training.
    model.eval()
    # T5 expects a task prefix; "summarize:" is the generic summarization tag.
    input_text = "summarize: " + code
    # Truncate long snippets to the encoder's 512-token budget.
    inputs = tokenizer(
        input_text, return_tensors="pt", truncation=True, max_length=512
    ).to(device)
    with torch.no_grad():
        out = model.generate(
            **inputs,
            max_length=max_length,
            num_beams=num_beams,
            early_stopping=True,
        )
    return tokenizer.decode(out[0], skip_special_tokens=True)
def main():
    """CLI entry point: print a generated docstring for the given code.

    ``--input`` accepts either a literal Python snippet or a path to a
    ``.py`` file; a path is detected explicitly (suffix + file existence)
    instead of the previous silent try/except heuristic.
    """
    parser = argparse.ArgumentParser(
        description="Generate a docstring for a Python code snippet."
    )
    parser.add_argument(
        "--input", type=str, required=True,
        help="Python code snippet (or path to file)",
    )
    parser.add_argument("--model_name", type=str, default="t5-small")
    parser.add_argument("--max_length", type=int, default=128)
    parser.add_argument("--num_beams", type=int, default=4)
    args = parser.parse_args()

    code = args.input
    # Treat the input as a file path only when it looks like one AND the
    # file actually exists; otherwise it is a literal code snippet.
    if code.endswith(".py") and Path(code).is_file():
        code = Path(code).read_text(encoding="utf-8")

    docstring = generate_docstring(
        code,
        model_name=args.model_name,
        max_length=args.max_length,
        num_beams=args.num_beams,
    )
    print(docstring)
# Run the CLI only when executed as a script, not when imported.
if __name__ == "__main__":
    main()