vuln_detector/llama_predict.py
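"""Inference script for the merged code-vulnerability-detection LLaMA model.

Loads the merged model and tokenizer, builds the same security-analysis prompt
used during training, and prints the model's vulnerability diagnosis for code
supplied via --code, --code_file, or a built-in example.
"""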
import os
import argparse
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, GenerationConfig

def parse_args():
    p = argparse.ArgumentParser(description="LLAMA inference for code vulnerability diagnosis")
    p.add_argument("--model", type=str, required=True, help="Path to the merged model (e.g. ./merged-vuln-detector)")
    p.add_argument("--code", type=str, help="Code passed in directly (optional)")
    p.add_argument("--code_file", type=str, help="Path to the function code file to analyze (optional)")
    p.add_argument("--max_new_tokens", type=int, default=512)
    p.add_argument("--dtype", type=str, default="fp16", choices=["fp32", "fp16", "bf16"])
    return p.parse_args()

def resolve_dtype(dtype_flag):
    if dtype_flag == "fp32":
        return torch.float32
    if dtype_flag in ["fp16", "auto"]:
        return torch.float16 if torch.cuda.is_available() else torch.float32
    if dtype_flag == "bf16":
        return torch.bfloat16 if torch.cuda.is_available() else torch.float32
    return torch.float32

def build_prompt(code):
    # Build the security-expert style prompt, matching the one used during training
    return (
        "Analyze the security vulnerabilities in the following code.\n"
        + code
        + "\n\nAnalysis:\n"
    )

def main():
    args = parse_args()
    dtype = resolve_dtype(args.dtype)

    # Load model/tokenizer
    tokenizer = AutoTokenizer.from_pretrained(args.model, use_fast=True)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    model = AutoModelForCausalLM.from_pretrained(
        args.model,
        dtype=dtype,  # note: older transformers releases expect torch_dtype instead of dtype
        device_map="auto"
    )
    model.eval()
    # Get the input code
    if args.code:
        test_code = args.code
        print("[Using directly entered code]")
    elif args.code_file and os.path.exists(args.code_file):
        with open(args.code_file, "r", encoding="utf-8") as f:
            test_code = f.read()
        print(f"[Loaded code file] {args.code_file}")
    else:
        # Example code (deliberately contains a vulnerability)
        test_code = (
            "def login(username, password):\n"
            "    query = f\"SELECT * FROM users WHERE username='{username}' AND password='{password}'\"\n"
            "    cursor.execute(query)\n"
            "    return cursor.fetchone()\n"
        )
        print("[Using default example code]")
print("="*50)
print(test_code)
print("="*50)
prompt = build_prompt(test_code)
inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
gen_cfg = GenerationConfig(
max_new_tokens=args.max_new_tokens,
do_sample=False,
pad_token_id=tokenizer.eos_token_id,
eos_token_id=tokenizer.eos_token_id,
)
with torch.inference_mode():
output = model.generate(**inputs, generation_config=gen_cfg)
input_len = inputs.input_ids.shape[1]
generated_ids = output[0, input_len:]
result = tokenizer.decode(generated_ids, skip_special_tokens=True)
print("[์ทจ์•ฝ์  ์ง„๋‹จ ๊ฒฐ๊ณผ]")
print(result.strip())
print("="*50)

if __name__ == "__main__":
    main()
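
# Example invocation (paths are illustrative; the model directory follows the
# --model help text and vuln.py stands for any file holding the code to check):
#
#   python llama_predict.py --model ./merged-vuln-detector --code_file vuln.py --dtype bf16
#
# Running without --code / --code_file falls back to the built-in example code.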