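"""Inference script for code vulnerability detection with a fine-tuned LLaMA model.

Loads a merged model, wraps the target code in the same security-analysis
prompt used during training, and prints the model's greedy-decoded
diagnosis. See the example invocations at the bottom of this file.
"""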
import os
import argparse
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, GenerationConfig

def parse_args():
    p = argparse.ArgumentParser(description="LLaMA inference for code vulnerability detection")
    p.add_argument("--model", type=str, required=True, help="Path to the merged model (e.g. ./merged-vuln-detector)")
    p.add_argument("--code", type=str, help="Code passed directly on the command line (optional)")
    p.add_argument("--code_file", type=str, help="Path to a source file containing the function to analyze (optional)")
    p.add_argument("--max_new_tokens", type=int, default=512)
    p.add_argument("--dtype", type=str, default="fp16", choices=["auto", "fp32", "fp16", "bf16"])
    return p.parse_args()


def resolve_dtype(dtype_flag):
    # Map the CLI flag to a torch dtype; fall back to fp32 when CUDA
    # (or bf16 hardware support) is unavailable.
    if dtype_flag == "fp32":
        return torch.float32
    if dtype_flag in ["fp16", "auto"]:
        return torch.float16 if torch.cuda.is_available() else torch.float32
    if dtype_flag == "bf16":
        if torch.cuda.is_available() and torch.cuda.is_bf16_supported():
            return torch.bfloat16
        return torch.float32
    return torch.float32

def build_prompt(code):
    # Build the security-analyst style prompt, identical to the one used during training
    return (
        "Analyze the security vulnerabilities in the following code.\n"
        + code
        + "\n\nAnalysis:\n"
    )
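# For reference, build_prompt('cursor.execute(query)') renders as:
#   Analyze the security vulnerabilities in the following code.
#   cursor.execute(query)
#
#   Analysis: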

def main():
    args = parse_args()
    dtype = resolve_dtype(args.dtype)

    # ๋ชจ๋ธ/ํ† ํฌ๋‚˜์ด์ € ๋กœ๋“œ
    tokenizer = AutoTokenizer.from_pretrained(args.model, use_fast=True)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    model = AutoModelForCausalLM.from_pretrained(
        args.model,
        dtype=dtype,
        device_map="auto"
    )
    model.eval()

    # Collect the code to analyze: --code wins, then --code_file, then a
    # built-in example.
    if args.code:
        test_code = args.code
        print("[Using code passed via --code]")
    elif args.code_file:
        if not os.path.exists(args.code_file):
            raise FileNotFoundError(f"--code_file not found: {args.code_file}")
        with open(args.code_file, "r", encoding="utf-8") as f:
            test_code = f.read()
        print(f"[Loaded code file] {args.code_file}")
    else:
        # Built-in example containing a SQL injection vulnerability.
        test_code = (
            "def login(username, password):\n"
            "    query = f\"SELECT * FROM users WHERE username='{username}' AND password='{password}'\"\n"
            "    cursor.execute(query)\n"
            "    return cursor.fetchone()\n"
        )
        print("[Using built-in example code]")

    print("="*50)
    print(test_code)
    print("="*50)

    prompt = build_prompt(test_code)
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

    gen_cfg = GenerationConfig(
        max_new_tokens=args.max_new_tokens,
        do_sample=False,  # greedy decoding for reproducible diagnoses
        pad_token_id=tokenizer.eos_token_id,
        eos_token_id=tokenizer.eos_token_id,
    )
    with torch.inference_mode():
        output = model.generate(**inputs, generation_config=gen_cfg)

    # Decode only the newly generated tokens, skipping the prompt.
    input_len = inputs.input_ids.shape[1]
    generated_ids = output[0, input_len:]
    result = tokenizer.decode(generated_ids, skip_special_tokens=True)

    print("[์ทจ์•ฝ์  ์ง„๋‹จ ๊ฒฐ๊ณผ]")
    print(result.strip())
    print("="*50)

if __name__ == "__main__":
    main()
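
# Example invocations (the script and file names below are illustrative):
#   python infer.py --model ./merged-vuln-detector
#   python infer.py --model ./merged-vuln-detector --code_file target.py --dtype bf16
#   python infer.py --model ./merged-vuln-detector --code "eval(input())"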