""" Inference script for docstring generation from Python code. Uses Hugging Face Transformers (T5 or CodeT5). """ import argparse from transformers import AutoModelForSeq2SeqLM, AutoTokenizer import torch def generate_docstring( code: str, model_name: str = "t5-small", max_length: int = 128, num_beams: int = 4, device: str = None, ) -> str: if device is None: device = "cuda" if torch.cuda.is_available() else "cpu" tokenizer = AutoTokenizer.from_pretrained(model_name) model = AutoModelForSeq2SeqLM.from_pretrained(model_name).to(device) # T5 expects a prefix for the task; we use "summarize:" for generic text/code summary input_text = "summarize: " + code inputs = tokenizer(input_text, return_tensors="pt", truncation=True, max_length=512).to(device) with torch.no_grad(): out = model.generate( **inputs, max_length=max_length, num_beams=num_beams, early_stopping=True, ) return tokenizer.decode(out[0], skip_special_tokens=True) def main(): parser = argparse.ArgumentParser() parser.add_argument("--input", type=str, required=True, help="Python code snippet (or path to file)") parser.add_argument("--model_name", type=str, default="t5-small") parser.add_argument("--max_length", type=int, default=128) parser.add_argument("--num_beams", type=int, default=4) args = parser.parse_args() code = args.input if len(code) < 260 and code.endswith(".py"): try: with open(code, "r") as f: code = f.read() except Exception: pass docstring = generate_docstring( code, model_name=args.model_name, max_length=args.max_length, num_beams=args.num_beams, ) print(docstring) if __name__ == "__main__": main()