InventorsHub committed on
Commit
c5b3be4
·
verified ·
1 Parent(s): 0d4b6b6

Update safety_module.py

Browse files
Files changed (1) hide show
  1. safety_module.py +92 -80
safety_module.py CHANGED
@@ -1,81 +1,93 @@
1
- from llama_cpp import Llama
2
- from huggingface_hub import hf_hub_download
3
-
4
-
5
- # Download only the safety model shard from the model repo
6
- model_path = hf_hub_download(
7
- repo_id="Inventors-Hub/SwarmChat-models",
8
- repo_type="model",
9
- filename="llama-guard-3-8b-q4_k_m.gguf",
10
- )
11
-
12
-
13
- # Load the Llama model for safety classification
14
- # model_path = r".\models\llama-guard-3-8b-q4_k_m.gguf"
15
-
16
-
17
- # llm = Llama(model_path=model_path, n_ctx=1024)
18
- llm = Llama(
19
- model_path=model_path,
20
- n_ctx=512, # down from 4096
21
- low_vram=True, # llama.cpp low-vram mode
22
- f16_kv=True, # half-precision kv cache
23
- use_mmap=True, # mmap file
24
- use_mlock=False,
25
- )
26
-
27
-
28
- def check_safety(text, checkbox):
29
-
30
- if checkbox:
31
- return "Safe"
32
-
33
- categories = {
34
- "S1": "Violent Crimes.",
35
- "S2": "Non-Violent Crimes.",
36
- "S3": "Sex-Related Crimes.",
37
- "S4": "Child Sexual Exploitation.",
38
- "S5": "Defamation.",
39
- "S6": "Specialized Advice.",
40
- "S7": "Privacy.",
41
- "S8": "Intellectual Property.",
42
- "S9": "Indiscriminate Weapons.",
43
- "S10": "Hate.",
44
- "S11": "Suicide & Self-Harm.",
45
- "S12": "Sexual Content.",
46
- "S13": "Elections.",
47
- "S14": "Code Interpreter Abuse.",
48
- }
49
-
50
- input_prompt = f"""
51
- <|begin_of_text|><|start_header_id|>user<|end_header_id|>
52
-
53
- Task: Determine if the following text contains unsafe content according to our safety policy. Classify the content into one or more of the categories below if unsafe.
54
-
55
- <BEGIN UNSAFE CONTENT CATEGORIES>
56
- {', '.join([f'{key}: {value}' for key, value in categories.items()])}
57
- <END UNSAFE CONTENT CATEGORIES>
58
-
59
- Text: {text}
60
-
61
- Provide your safety assessment:
62
- - Respond with "safe" if the text is appropriate.
63
- - If unsafe, respond with: "unsafe" followed by a comma-separated list of the violated categories (e.g., S1, S2).
64
-
65
- <|eot_id|>
66
- """
67
-
68
- # Generate a response from the model
69
- output = llm(input_prompt, temperature=0)
70
- # Extract the model's response
71
- response = output.get("choices", [{}])[0].get("text", "").strip()
72
- print("\n\n", response, "\n\n")
73
- # Parse the response
74
- if response.lower().startswith("safe"):
75
- return "Safe"
76
- else:
77
- unsafe_categories = categories[response.split("unsafe", 1)[-1].strip()]
78
- return f"Unsafe: This prompt is categorized as '{unsafe_categories}'"
79
-
80
- # unsafe_categories = categories[response.split("unsafe", 1)[-1].strip()]
 
 
 
 
 
 
 
 
 
 
 
 
81
  # return f"Unsafe: This prompt categorized as '{unsafe_categories}'"
 
1
+ from llama_cpp import Llama
2
+ from huggingface_hub import hf_hub_download
3
+
4
+
5
+ # Download only the safety model shard from the model repo
6
+ model_path = hf_hub_download(
7
+ repo_id="Inventors-Hub/SwarmChat-models",
8
+ repo_type="model",
9
+ filename="llama-guard-3-8b-q4_k_m.gguf",
10
+ )
11
+
12
+
13
+ # Load the Llama model for safety classification
14
+ # model_path = r".\models\llama-guard-3-8b-q4_k_m.gguf"
15
+
16
+
17
+ # llm = Llama(model_path=model_path, n_ctx=1024)
18
+ # llm = Llama(
19
+ # model_path=model_path,
20
+ # n_ctx=512, # down from 4096
21
+ # low_vram=True, # llama.cpp low-vram mode
22
+ # f16_kv=True, # half-precision kv cache
23
+ # use_mmap=True, # mmap file
24
+ # use_mlock=False,
25
+ # )
26
# Lazily-constructed module-level model handle; populated on the first call
# to llm_gpu() so the GGUF file is not loaded at import time.
llm = None


# NOTE(review): `spaces` is never imported in this file — `import spaces`
# (the Hugging Face Spaces package) must be added at the top of the file,
# otherwise this decorator raises NameError at definition time.
@spaces.GPU
def llm_gpu():
    """Return the Llama-Guard safety classifier, loading it on first use.

    Caches the loaded model in the module-level ``llm`` so repeated calls
    (one per safety check) reuse the instance instead of re-reading the
    multi-gigabyte GGUF file from disk on every invocation.

    Returns:
        Llama: the loaded llama.cpp model.
    """
    global llm  # bug fix: the original assigned a function-local `llm`,
    # leaving the module global as None and reloading the model every call.
    if llm is None:
        llm = Llama(
            model_path=model_path,
            n_ctx=512,        # down from 4096 to save VRAM
            low_vram=True,    # llama.cpp low-vram mode
            f16_kv=True,      # half-precision KV cache
            use_mmap=True,    # mmap the model file
            use_mlock=False,
        )
    return llm
39
+
40
+ def check_safety(text, checkbox):
41
+
42
+ if checkbox:
43
+ return "Safe"
44
+
45
+ categories = {
46
+ "S1": "Violent Crimes.",
47
+ "S2": "Non-Violent Crimes.",
48
+ "S3": "Sex-Related Crimes.",
49
+ "S4": "Child Sexual Exploitation.",
50
+ "S5": "Defamation.",
51
+ "S6": "Specialized Advice.",
52
+ "S7": "Privacy.",
53
+ "S8": "Intellectual Property.",
54
+ "S9": "Indiscriminate Weapons.",
55
+ "S10": "Hate.",
56
+ "S11": "Suicide & Self-Harm.",
57
+ "S12": "Sexual Content.",
58
+ "S13": "Elections.",
59
+ "S14": "Code Interpreter Abuse.",
60
+ }
61
+
62
+ input_prompt = f"""
63
+ <|begin_of_text|><|start_header_id|>user<|end_header_id|>
64
+
65
+ Task: Determine if the following text contains unsafe content according to our safety policy. Classify the content into one or more of the categories below if unsafe.
66
+
67
+ <BEGIN UNSAFE CONTENT CATEGORIES>
68
+ {', '.join([f'{key}: {value}' for key, value in categories.items()])}
69
+ <END UNSAFE CONTENT CATEGORIES>
70
+
71
+ Text: {text}
72
+
73
+ Provide your safety assessment:
74
+ - Respond with "safe" if the text is appropriate.
75
+ - If unsafe, respond with: "unsafe" followed by a comma-separated list of the violated categories (e.g., S1, S2).
76
+
77
+ <|eot_id|>
78
+ """
79
+
80
+ # Generate a response from the model
81
+ output = llm(input_prompt, temperature=0)
82
+ # Extract the model's response
83
+ response = output.get("choices", [{}])[0].get("text", "").strip()
84
+ print("\n\n", response, "\n\n")
85
+ # Parse the response
86
+ if response.lower().startswith("safe"):
87
+ return "Safe"
88
+ else:
89
+ unsafe_categories = categories[response.split("unsafe", 1)[-1].strip()]
90
+ return f"Unsafe: This prompt is categorized as '{unsafe_categories}'"
91
+
92
+ # unsafe_categories = categories[response.split("unsafe", 1)[-1].strip()]
93
  # return f"Unsafe: This prompt categorized as '{unsafe_categories}'"