"""Inference script for docstring generation from Python code.

Uses Hugging Face Transformers (T5 or CodeT5).
"""
|
|
|
import argparse
from functools import lru_cache
from pathlib import Path
from typing import Optional

import torch
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
|
|
|
|
|
@lru_cache(maxsize=4)
def _load_model(model_name: str, device: str):
    """Load and cache the tokenizer/model pair so repeated calls skip re-loading.

    Caching is keyed on (model_name, device); the model is switched to
    eval mode since this script only does inference.
    """
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForSeq2SeqLM.from_pretrained(model_name).to(device)
    model.eval()
    return tokenizer, model


def generate_docstring(
    code: str,
    model_name: str = "t5-small",
    max_length: int = 128,
    num_beams: int = 4,
    device: Optional[str] = None,
) -> str:
    """Generate a docstring/summary for a Python code snippet.

    Args:
        code: Source code to summarize.
        model_name: Hugging Face model id (T5 or CodeT5 style seq2seq).
        max_length: Maximum length of the generated sequence.
        num_beams: Beam-search width for generation.
        device: Torch device string; auto-detects CUDA when None.

    Returns:
        The decoded model output with special tokens stripped.
    """
    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"

    # Cached load: the original re-downloaded/re-instantiated the model on
    # every call, which dominates runtime for repeated invocations.
    tokenizer, model = _load_model(model_name, device)

    # T5-family models expect a task prefix on the input text.
    input_text = "summarize: " + code
    inputs = tokenizer(
        input_text, return_tensors="pt", truncation=True, max_length=512
    ).to(device)

    with torch.no_grad():
        out = model.generate(
            **inputs,
            max_length=max_length,
            num_beams=num_beams,
            early_stopping=True,
        )

    return tokenizer.decode(out[0], skip_special_tokens=True)
|
|
|
|
|
def main() -> None:
    """CLI entry point: summarize --input (a code snippet or a .py path).

    Prints the generated docstring to stdout.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--input", type=str, required=True, help="Python code snippet (or path to file)")
    parser.add_argument("--model_name", type=str, default="t5-small")
    parser.add_argument("--max_length", type=int, default=128)
    parser.add_argument("--num_beams", type=int, default=4)
    args = parser.parse_args()

    code = args.input
    # Treat the argument as a file path only when it actually names a file.
    # The original `len(code) < 260 and endswith(".py")` heuristic with a
    # bare `except Exception: pass` silently swallowed real read errors
    # (permissions, encoding) and fell back to summarizing the path string.
    path = Path(code)
    if code.endswith(".py") and path.is_file():
        code = path.read_text(encoding="utf-8")

    docstring = generate_docstring(
        code,
        model_name=args.model_name,
        max_length=args.max_length,
        num_beams=args.num_beams,
    )
    print(docstring)
|
|
|
|
|
# Run the CLI only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
|
|
|