"""Inference script: diagnose code vulnerabilities with a merged LLaMA model."""

import os
import argparse

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, GenerationConfig


def parse_args():
    p = argparse.ArgumentParser(description="LLAMA inference for code vulnerability diagnosis")
    p.add_argument("--model", type=str, required=True,
                   help="Merged model path (example: ./merged-vuln-detector)")
    p.add_argument("--code", type=str, help="Code to analyze, passed directly as a string (optional)")
    p.add_argument("--code_file", type=str, help="Path to the function-code file to run inference on (optional)")
    p.add_argument("--max_new_tokens", type=int, default=512)
    # "auto" is handled by resolve_dtype() below, so it must be an accepted choice.
    p.add_argument("--dtype", type=str, default="fp16", choices=["auto", "fp32", "fp16", "bf16"])
    return p.parse_args()
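
# Example invocation (script name and paths are illustrative, not fixed by this file):
#   python infer.py --model ./merged-vuln-detector --code_file sample.py --dtype bf16
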

def resolve_dtype(dtype_flag):
    if dtype_flag == "fp32":
        return torch.float32
    if dtype_flag in ["fp16", "auto"]:
        # Half precision is only useful on GPU; fall back to fp32 on CPU.
        return torch.float16 if torch.cuda.is_available() else torch.float32
    if dtype_flag == "bf16":
        return torch.bfloat16 if torch.cuda.is_available() else torch.float32
    return torch.float32
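
# Note: bf16 also needs hardware support (Ampere or newer). A stricter check,
# if desired, would gate on torch.cuda.is_bf16_supported(); omitted above to
# keep the flag handling simple.
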

def build_prompt(code):
    # Plain completion-style prompt: the model continues after "Analysis:".
    return (
        "Analyze the security vulnerabilities in the following code.\n"
        + code
        + "\n\nAnalysis:\n"
    )
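
# If the checkpoint were chat-tuned rather than completion-tuned (an assumption
# either way), the prompt would instead be built via the tokenizer's chat
# template, e.g.:
#   messages = [{"role": "user", "content": build_prompt(code)}]
#   prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
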

def main():
    args = parse_args()
    dtype = resolve_dtype(args.dtype)

    tokenizer = AutoTokenizer.from_pretrained(args.model, use_fast=True)
    # LLaMA tokenizers ship without a pad token; reuse EOS so padding works.
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    model = AutoModelForCausalLM.from_pretrained(
        args.model,
        torch_dtype=dtype,
        device_map="auto",  # requires the accelerate package
    )
    model.eval()
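
    # Low-VRAM alternative (a sketch, assuming bitsandbytes is installed; not
    # used above): load the merged model in 4-bit instead of fp16/bf16.
    #   from transformers import BitsAndBytesConfig
    #   model = AutoModelForCausalLM.from_pretrained(
    #       args.model,
    #       quantization_config=BitsAndBytesConfig(load_in_4bit=True),
    #       device_map="auto",
    #   )
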

    if args.code:
        test_code = args.code
        print("[Using code passed directly on the command line]")
    elif args.code_file and os.path.exists(args.code_file):
        with open(args.code_file, "r", encoding="utf-8") as f:
            test_code = f.read()
        print(f"[Code file loaded] {args.code_file}")
    else:
        # Fallback: a classic SQL-injection example (string-interpolated query).
        # The safe variant would pass parameters separately, e.g.
        # cursor.execute("SELECT ... WHERE username=%s", (username,)).
        test_code = (
            "def login(username, password):\n"
            "    query = f\"SELECT * FROM users WHERE username='{username}' AND password='{password}'\"\n"
            "    cursor.execute(query)\n"
            "    return cursor.fetchone()\n"
        )
        print("[Using default example code]")

print("="*50)
|
|
|
print(test_code)
|
|
|
print("="*50)
|
|
|
|
|
|
prompt = build_prompt(test_code)
|
|
|
inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
|
|
|
|
|
|
gen_cfg = GenerationConfig(
|
|
|
max_new_tokens=args.max_new_tokens,
|
|
|
do_sample=False,
|
|
|
pad_token_id=tokenizer.eos_token_id,
|
|
|
eos_token_id=tokenizer.eos_token_id,
|
|
|
)
|
|
|
with torch.inference_mode():
|
|
|
output = model.generate(**inputs, generation_config=gen_cfg)
|
|
|
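
    # To stream tokens to stdout as they are generated (a sketch, not wired in
    # above), transformers' TextStreamer could be passed to generate():
    #   from transformers import TextStreamer
    #   streamer = TextStreamer(tokenizer, skip_prompt=True)
    #   model.generate(**inputs, generation_config=gen_cfg, streamer=streamer)
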

    # Decode only the newly generated tokens, skipping the echoed prompt.
    input_len = inputs.input_ids.shape[1]
    generated_ids = output[0, input_len:]
    result = tokenizer.decode(generated_ids, skip_special_tokens=True)

    print("[Vulnerability diagnosis result]")
    print(result.strip())
    print("=" * 50)


if __name__ == "__main__":
    main()