Spaces:
Running on Zero
Running on Zero
Upload src/veris_classifier/classifier.py with huggingface_hub
Browse files
src/veris_classifier/classifier.py
CHANGED
|
@@ -91,15 +91,28 @@ def load_hf_model():
|
|
| 91 |
|
| 92 |
def _generate_hf(messages: list[dict], max_new_tokens: int = 1024) -> str:
|
| 93 |
"""Generate a response using the fine-tuned HF model."""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 94 |
pipe, tokenizer = load_hf_model()
|
| 95 |
|
| 96 |
-
|
| 97 |
-
|
| 98 |
-
|
| 99 |
-
|
| 100 |
-
|
| 101 |
-
|
| 102 |
-
|
|
|
|
|
|
|
| 103 |
|
| 104 |
return outputs[0]["generated_text"].strip()
|
| 105 |
|
|
@@ -129,6 +142,32 @@ def _generate_openai(
|
|
| 129 |
return response.choices[0].message.content.strip()
|
| 130 |
|
| 131 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 132 |
# ββ Public API ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 133 |
|
| 134 |
|
|
@@ -155,7 +194,7 @@ def classify_incident(
|
|
| 155 |
]
|
| 156 |
|
| 157 |
if use_hf:
|
| 158 |
-
raw =
|
| 159 |
else:
|
| 160 |
if client is None:
|
| 161 |
raise ValueError("OpenAI client required when use_hf=False")
|
|
@@ -163,15 +202,7 @@ def classify_incident(
|
|
| 163 |
client, messages, model=model, temperature=0.2, json_mode=True
|
| 164 |
)
|
| 165 |
|
| 166 |
-
|
| 167 |
-
text = raw.strip()
|
| 168 |
-
if text.startswith("```"):
|
| 169 |
-
# Strip ```json ... ``` wrapper
|
| 170 |
-
lines = text.split("\n")
|
| 171 |
-
text = "\n".join(lines[1:-1]) if len(lines) > 2 else text
|
| 172 |
-
text = text.strip()
|
| 173 |
-
|
| 174 |
-
return json.loads(text)
|
| 175 |
|
| 176 |
|
| 177 |
def answer_question(
|
|
|
|
| 91 |
|
| 92 |
def _generate_hf(messages: list[dict], max_new_tokens: int = 1024) -> str:
    """Generate a response using the fine-tuned HF model."""
    # Thin convenience wrapper: defer to the options-aware variant and let it
    # apply the default sampling configuration.
    return _generate_hf_with_options(messages, max_new_tokens)
| 96 |
+
|
| 97 |
+
def _generate_hf_with_options(
    messages: list[dict],
    max_new_tokens: int = 1024,
    do_sample: bool = True,
    temperature: float = 0.2,
    top_p: float = 0.9,
) -> str:
    """Generate a response using the fine-tuned HF model with explicit sampling controls."""
    pipe, tokenizer = load_hf_model()

    # temperature/top_p are sampling knobs; forward them only when sampling
    # is actually enabled so greedy decoding gets none of them.
    sampling_opts = (
        {"temperature": temperature, "top_p": top_p} if do_sample else {}
    )

    outputs = pipe(
        messages,
        max_new_tokens=max_new_tokens,
        do_sample=do_sample,
        **sampling_opts,
    )

    return outputs[0]["generated_text"].strip()
| 118 |
|
|
|
|
| 142 |
return response.choices[0].message.content.strip()
|
| 143 |
|
| 144 |
|
| 145 |
+
def _parse_json_response(raw: str) -> dict:
|
| 146 |
+
"""Parse model output into JSON with light recovery for wrapped text."""
|
| 147 |
+
text = raw.strip()
|
| 148 |
+
try:
|
| 149 |
+
return json.loads(text)
|
| 150 |
+
except json.JSONDecodeError:
|
| 151 |
+
pass
|
| 152 |
+
|
| 153 |
+
if text.startswith("```"):
|
| 154 |
+
lines = text.split("\n")
|
| 155 |
+
text = "\n".join(lines[1:-1]) if len(lines) > 2 else text
|
| 156 |
+
text = text.strip()
|
| 157 |
+
try:
|
| 158 |
+
return json.loads(text)
|
| 159 |
+
except json.JSONDecodeError:
|
| 160 |
+
pass
|
| 161 |
+
|
| 162 |
+
# Recover when the model prepends/appends prose around a JSON object.
|
| 163 |
+
start = text.find("{")
|
| 164 |
+
end = text.rfind("}")
|
| 165 |
+
if start != -1 and end != -1 and end > start:
|
| 166 |
+
return json.loads(text[start : end + 1])
|
| 167 |
+
|
| 168 |
+
raise json.JSONDecodeError("No JSON object found in model output", text, 0)
|
| 169 |
+
|
| 170 |
+
|
| 171 |
# ββ Public API ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 172 |
|
| 173 |
|
|
|
|
| 194 |
]
|
| 195 |
|
| 196 |
if use_hf:
|
| 197 |
+
raw = _generate_hf_with_options(messages, max_new_tokens=1024, do_sample=False)
|
| 198 |
else:
|
| 199 |
if client is None:
|
| 200 |
raise ValueError("OpenAI client required when use_hf=False")
|
|
|
|
| 202 |
client, messages, model=model, temperature=0.2, json_mode=True
|
| 203 |
)
|
| 204 |
|
| 205 |
+
return _parse_json_response(raw)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 206 |
|
| 207 |
|
| 208 |
def answer_question(
|