File size: 3,230 Bytes
ab1c015 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 |
import argparse
import os
import sys

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, GenerationConfig
def parse_args(argv=None):
    """Parse the command-line options of the vulnerability-diagnosis script.

    Args:
        argv: Argument list to parse. ``None`` (the default) makes argparse
            read ``sys.argv[1:]``, which is what the script always did.

    Returns:
        argparse.Namespace with ``model``, ``code``, ``code_file``,
        ``max_new_tokens`` and ``dtype``.
    """
    p = argparse.ArgumentParser(description="코드 취약점 진단 LLAMA 추론")
    p.add_argument(
        "--model",
        type=str,
        required=True,
        help="병합된 모델 경로 (예시: ./merged-vuln-detector)",
    )
    p.add_argument("--code", type=str, help="직접 입력한 코드(선택)")
    p.add_argument("--code_file", type=str, help="추론할 함수 코드 파일 경로(선택)")
    p.add_argument("--max_new_tokens", type=int, default=512)
    # "auto" is accepted because resolve_dtype() has a branch for it (it is
    # treated like "fp16"); without it in `choices` that branch was unreachable.
    p.add_argument(
        "--dtype",
        type=str,
        default="fp16",
        choices=["fp32", "fp16", "bf16", "auto"],
    )
    return p.parse_args(argv)
def resolve_dtype(dtype_flag):
    """Translate a ``--dtype`` flag into a ``torch.dtype``.

    ``"fp16"`` and ``"auto"`` request ``torch.float16``; ``"bf16"`` requests
    ``torch.bfloat16``. Either request is honoured only when CUDA is
    available, otherwise ``torch.float32`` is used. ``"fp32"`` and any
    unrecognised flag always give ``torch.float32``.
    """
    if dtype_flag in ("fp16", "auto"):
        requested = torch.float16
    elif dtype_flag == "bf16":
        requested = torch.bfloat16
    else:
        # "fp32" or an unknown flag: full precision, no CUDA check needed.
        return torch.float32
    # Half-precision is only used on a CUDA device; CPU runs stay in fp32.
    return requested if torch.cuda.is_available() else torch.float32
def build_prompt(code):
    """Return the inference prompt that wraps *code*.

    Layout: an instruction line, the code verbatim, a blank line, then an
    ``Analysis:`` cue after which the model continues. The author's note says
    this mirrors the prompt used during training, so keep the two in sync.
    """
    instruction = "Analyze the security vulnerabilities in the following code.\n"
    answer_cue = "\n\nAnalysis:\n"
    return "".join([instruction, code, answer_cue])
def _read_input_code(code_arg, code_file):
    """Choose the code to analyse: --code, else --code_file, else a built-in example.

    Prints a banner naming the chosen source. If --code_file was given but
    the path does not exist, a warning goes to stderr before falling back to
    the built-in example (this fallback used to happen silently).
    """
    if code_arg:
        print("[직접 입력 코드 사용]")
        return code_arg
    if code_file:
        if os.path.exists(code_file):
            with open(code_file, "r", encoding="utf-8") as f:
                code = f.read()
            print(f"[코드 파일 로드] {code_file}")
            return code
        print(
            f"[경고] 코드 파일을 찾을 수 없습니다: {code_file} (기본 예시 코드를 사용합니다)",
            file=sys.stderr,
        )
    # Fallback sample containing a vulnerability: the SQL query is built by
    # string interpolation of user-supplied values (SQL injection).
    print("[기본 예시 코드 사용]")
    return (
        "def login(username, password):\n"
        "    query = f\"SELECT * FROM users WHERE username='{username}' AND password='{password}'\"\n"
        "    cursor.execute(query)\n"
        "    return cursor.fetchone()\n"
    )


def _load_model_and_tokenizer(model_path, dtype):
    """Load the tokenizer and causal-LM from *model_path* and put the model in eval mode."""
    tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=True)
    if tokenizer.pad_token is None:
        # No pad token defined by this tokenizer: reuse EOS so padding is possible.
        tokenizer.pad_token = tokenizer.eos_token
    # NOTE(review): `dtype=` is the newer spelling of `torch_dtype=`; older
    # transformers releases only understand `torch_dtype` — confirm the
    # installed version.
    model = AutoModelForCausalLM.from_pretrained(
        model_path,
        dtype=dtype,
        device_map="auto",
    )
    model.eval()
    return tokenizer, model


def _generate_analysis(model, tokenizer, prompt, max_new_tokens):
    """Greedily generate a continuation of *prompt* and return only the new text, stripped."""
    # NOTE(review): no truncation is applied, so very long code files may
    # exceed the model's context window.
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    gen_cfg = GenerationConfig(
        max_new_tokens=max_new_tokens,
        do_sample=False,  # greedy decoding
        pad_token_id=tokenizer.eos_token_id,
        eos_token_id=tokenizer.eos_token_id,
    )
    with torch.inference_mode():
        output = model.generate(**inputs, generation_config=gen_cfg)
    # The returned sequence starts with the prompt tokens; keep only the new ones.
    prompt_len = inputs.input_ids.shape[1]
    generated_ids = output[0, prompt_len:]
    return tokenizer.decode(generated_ids, skip_special_tokens=True).strip()


def main():
    """CLI entry point: pick the input code, load the model, and print its diagnosis."""
    args = parse_args()
    dtype = resolve_dtype(args.dtype)
    # Resolve the input before the (slow) model load so problems surface early.
    test_code = _read_input_code(args.code, args.code_file)
    tokenizer, model = _load_model_and_tokenizer(args.model, dtype)

    separator = "=" * 50
    print(separator)
    print(test_code)
    print(separator)

    result = _generate_analysis(model, tokenizer, build_prompt(test_code), args.max_new_tokens)
    print("[취약점 진단 결과]")
    print(result)
    print(separator)
if __name__ == "__main__":
    # Run the CLI only when executed as a script, not when imported.
    main()
|