vibesecurityguy committed on
Commit
4a6f75b
·
verified ·
1 Parent(s): 265d130

Upload src/veris_classifier/classifier.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. src/veris_classifier/classifier.py +48 -17
src/veris_classifier/classifier.py CHANGED
@@ -91,15 +91,28 @@ def load_hf_model():
91
 
92
  def _generate_hf(messages: list[dict], max_new_tokens: int = 1024) -> str:
93
  """Generate a response using the fine-tuned HF model."""
 
 
 
 
 
 
 
 
 
 
 
94
  pipe, tokenizer = load_hf_model()
95
 
96
- outputs = pipe(
97
- messages,
98
- max_new_tokens=max_new_tokens,
99
- do_sample=True,
100
- temperature=0.2,
101
- top_p=0.9,
102
- )
 
 
103
 
104
  return outputs[0]["generated_text"].strip()
105
 
@@ -129,6 +142,32 @@ def _generate_openai(
129
  return response.choices[0].message.content.strip()
130
 
131
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
132
  # ── Public API ────────────────────────────────────────────────────────────
133
 
134
 
@@ -155,7 +194,7 @@ def classify_incident(
155
  ]
156
 
157
  if use_hf:
158
- raw = _generate_hf(messages, max_new_tokens=1024)
159
  else:
160
  if client is None:
161
  raise ValueError("OpenAI client required when use_hf=False")
@@ -163,15 +202,7 @@ def classify_incident(
163
  client, messages, model=model, temperature=0.2, json_mode=True
164
  )
165
 
166
- # Parse JSON from response (handle markdown fences if present)
167
- text = raw.strip()
168
- if text.startswith("```"):
169
- # Strip ```json ... ``` wrapper
170
- lines = text.split("\n")
171
- text = "\n".join(lines[1:-1]) if len(lines) > 2 else text
172
- text = text.strip()
173
-
174
- return json.loads(text)
175
 
176
 
177
  def answer_question(
 
91
 
92
def _generate_hf(messages: list[dict], max_new_tokens: int = 1024) -> str:
    """Generate a response using the fine-tuned HF model.

    Convenience entry point: defers to ``_generate_hf_with_options`` and
    keeps that helper's default sampling behaviour.
    """
    return _generate_hf_with_options(messages, max_new_tokens)
95
+
96
+
97
def _generate_hf_with_options(
    messages: list[dict],
    max_new_tokens: int = 1024,
    do_sample: bool = True,
    temperature: float = 0.2,
    top_p: float = 0.9,
) -> str:
    """Generate a response using the fine-tuned HF model with explicit sampling controls.

    Args:
        messages: Chat messages in the ``[{"role": ..., "content": ...}]``
            form the transformers pipeline accepts.
        max_new_tokens: Upper bound on the number of generated tokens.
        do_sample: When False, decode deterministically (greedy);
            ``temperature``/``top_p`` are then omitted entirely.
        temperature: Sampling temperature — only applied when ``do_sample``.
        top_p: Nucleus-sampling cutoff — only applied when ``do_sample``.

    Returns:
        The generated text with surrounding whitespace stripped.
    """
    # The loader also returns the tokenizer, but it is not needed here.
    pipe, _ = load_hf_model()

    generate_kwargs = {
        "max_new_tokens": max_new_tokens,
        "do_sample": do_sample,
    }
    if do_sample:
        # Sampling knobs are only meaningful when do_sample is True;
        # omitting them otherwise avoids spurious transformers warnings.
        generate_kwargs["temperature"] = temperature
        generate_kwargs["top_p"] = top_p

    outputs = pipe(messages, **generate_kwargs)

    return outputs[0]["generated_text"].strip()
118
 
 
142
  return response.choices[0].message.content.strip()
143
 
144
 
145
+ def _parse_json_response(raw: str) -> dict:
146
+ """Parse model output into JSON with light recovery for wrapped text."""
147
+ text = raw.strip()
148
+ try:
149
+ return json.loads(text)
150
+ except json.JSONDecodeError:
151
+ pass
152
+
153
+ if text.startswith("```"):
154
+ lines = text.split("\n")
155
+ text = "\n".join(lines[1:-1]) if len(lines) > 2 else text
156
+ text = text.strip()
157
+ try:
158
+ return json.loads(text)
159
+ except json.JSONDecodeError:
160
+ pass
161
+
162
+ # Recover when the model prepends/appends prose around a JSON object.
163
+ start = text.find("{")
164
+ end = text.rfind("}")
165
+ if start != -1 and end != -1 and end > start:
166
+ return json.loads(text[start : end + 1])
167
+
168
+ raise json.JSONDecodeError("No JSON object found in model output", text, 0)
169
+
170
+
171
  # ── Public API ────────────────────────────────────────────────────────────
172
 
173
 
 
194
  ]
195
 
196
  if use_hf:
197
+ raw = _generate_hf_with_options(messages, max_new_tokens=1024, do_sample=False)
198
  else:
199
  if client is None:
200
  raise ValueError("OpenAI client required when use_hf=False")
 
202
  client, messages, model=model, temperature=0.2, json_mode=True
203
  )
204
 
205
+ return _parse_json_response(raw)
 
 
 
 
 
 
 
 
206
 
207
 
208
  def answer_question(