Update app.py
Browse files
app.py
CHANGED
|
@@ -42,6 +42,7 @@ class Config(object):
|
|
| 42 |
self.max_len = 256
|
| 43 |
self.reasoning_max_len = 128
|
| 44 |
self.temperature = 0.1
|
|
|
|
| 45 |
self.DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
|
| 46 |
self.model_name = "Qwen/Qwen2.5-7B-Instruct"
|
| 47 |
# self.model_name = "deepseek-ai/DeepSeek-R1-Distill-Qwen-7B"
|
|
@@ -49,6 +50,7 @@ class Config(object):
|
|
| 49 |
# self.reasoning_model_name = "Qwen/Qwen2.5-7B-Instruct"
|
| 50 |
# self.reasoning_model_name = "mistralai/Mistral-7B-Instruct-v0.2"
|
| 51 |
|
|
|
|
| 52 |
config = Config()
|
| 53 |
|
| 54 |
|
|
@@ -89,6 +91,7 @@ def generate(prompt):
|
|
| 89 |
**inputs,
|
| 90 |
max_new_tokens=config.max_len,
|
| 91 |
temperature=config.temperature,
|
|
|
|
| 92 |
)
|
| 93 |
|
| 94 |
generated = outputs[0][inputs["input_ids"].shape[-1]:]
|
|
@@ -118,12 +121,45 @@ def reasoning_generate(prompt):
|
|
| 118 |
**inputs,
|
| 119 |
max_new_tokens=config.reasoning_max_len,
|
| 120 |
temperature=config.temperature,
|
|
|
|
| 121 |
)
|
| 122 |
|
| 123 |
generated = outputs[0][inputs["input_ids"].shape[-1]:]
|
| 124 |
|
| 125 |
return reasoning_tokenizer.decode(generated, skip_special_tokens=True).strip()
|
| 126 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 127 |
class Action(BaseModel):
|
| 128 |
tool: str = Field(...)
|
| 129 |
args: Dict
|
|
@@ -476,6 +512,8 @@ Response: <answer>
|
|
| 476 |
|
| 477 |
DO NOT add anything additional and return ONLY what is asked and in the format asked.
|
| 478 |
|
|
|
|
|
|
|
| 479 |
ONLY return a response if you are confident about the answer, otherwise return empty string.
|
| 480 |
|
| 481 |
If you output anything else, it is incorrect.
|
|
@@ -498,17 +536,37 @@ Information:
|
|
| 498 |
|
| 499 |
logger.info(f"Raw Output: {raw_output}")
|
| 500 |
|
| 501 |
-
output = raw_output.split("Response:")[-1].strip()
|
| 502 |
-
output = output.split("\n")[0].strip()
|
| 503 |
-
# match = re.search(r"Response:\s*(.*)", raw_output, re.IGNORECASE)
|
| 504 |
-
# output = match.group(1).strip() if match else ""
|
|
|
|
|
|
|
|
|
|
| 505 |
|
| 506 |
-
if len(output) > 2 and output[
|
| 507 |
-
|
| 508 |
|
| 509 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 510 |
output = output[:-1]
|
| 511 |
|
|
|
|
|
|
|
| 512 |
state["output"] = output
|
| 513 |
|
| 514 |
logger.info(f"State (Safety Agent): {state}")
|
|
|
|
| 42 |
self.max_len = 256
|
| 43 |
self.reasoning_max_len = 128
|
| 44 |
self.temperature = 0.1
|
| 45 |
+
self.repetition_penalty = 1.2
|
| 46 |
self.DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
|
| 47 |
self.model_name = "Qwen/Qwen2.5-7B-Instruct"
|
| 48 |
# self.model_name = "deepseek-ai/DeepSeek-R1-Distill-Qwen-7B"
|
|
|
|
| 50 |
# self.reasoning_model_name = "Qwen/Qwen2.5-7B-Instruct"
|
| 51 |
# self.reasoning_model_name = "mistralai/Mistral-7B-Instruct-v0.2"
|
| 52 |
|
| 53 |
+
|
| 54 |
config = Config()
|
| 55 |
|
| 56 |
|
|
|
|
| 91 |
**inputs,
|
| 92 |
max_new_tokens=config.max_len,
|
| 93 |
temperature=config.temperature,
|
| 94 |
+
repetition_penalty = config.repetition_penalty,
|
| 95 |
)
|
| 96 |
|
| 97 |
generated = outputs[0][inputs["input_ids"].shape[-1]:]
|
|
|
|
| 121 |
**inputs,
|
| 122 |
max_new_tokens=config.reasoning_max_len,
|
| 123 |
temperature=config.temperature,
|
| 124 |
+
repetition_penalty = config.repetition_penalty,
|
| 125 |
)
|
| 126 |
|
| 127 |
generated = outputs[0][inputs["input_ids"].shape[-1]:]
|
| 128 |
|
| 129 |
return reasoning_tokenizer.decode(generated, skip_special_tokens=True).strip()
|
| 130 |
|
| 131 |
+
|
| 132 |
+
def reasoning_generate(prompt):
|
| 133 |
+
"""
|
| 134 |
+
Generate a text completion from a causal language model given a prompt.
|
| 135 |
+
|
| 136 |
+
Parameters
|
| 137 |
+
----------
|
| 138 |
+
prompt : str
|
| 139 |
+
Input text prompt used to condition the language model.
|
| 140 |
+
|
| 141 |
+
Returns
|
| 142 |
+
-------
|
| 143 |
+
str
|
| 144 |
+
The generated continuation text, decoded into a string with special
|
| 145 |
+
tokens removed and leading/trailing whitespace stripped.
|
| 146 |
+
|
| 147 |
+
"""
|
| 148 |
+
inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
|
| 149 |
+
|
| 150 |
+
with torch.no_grad():
|
| 151 |
+
outputs = model.generate(
|
| 152 |
+
**inputs,
|
| 153 |
+
max_new_tokens=config.reasoning_max_len,
|
| 154 |
+
temperature=config.temperature,
|
| 155 |
+
repetition_penalty = config.repetition_penalty,
|
| 156 |
+
)
|
| 157 |
+
|
| 158 |
+
generated = outputs[0][inputs["input_ids"].shape[-1]:]
|
| 159 |
+
|
| 160 |
+
return tokenizer.decode(generated, skip_special_tokens=True).strip()
|
| 161 |
+
|
| 162 |
+
|
| 163 |
class Action(BaseModel):
|
| 164 |
tool: str = Field(...)
|
| 165 |
args: Dict
|
|
|
|
| 512 |
|
| 513 |
DO NOT add anything additional and return ONLY what is asked and in the format asked.
|
| 514 |
|
| 515 |
+
If you output anything else, it is incorrect.
|
| 516 |
+
|
| 517 |
ONLY return a response if you are confident about the answer, otherwise return empty string.
|
| 518 |
|
| 519 |
If you output anything else, it is incorrect.
|
|
|
|
| 536 |
|
| 537 |
logger.info(f"Raw Output: {raw_output}")
|
| 538 |
|
| 539 |
+
# output = raw_output.split("Response:")[-1].strip()
|
| 540 |
+
# output = output.split("\n")[0].strip()
|
| 541 |
+
# # match = re.search(r"Response:\s*(.*)", raw_output, re.IGNORECASE)
|
| 542 |
+
# # output = match.group(1).strip() if match else ""
|
| 543 |
+
|
| 544 |
+
# if len(output) > 2 and output[0] == '"' and output[-1] == '"':
|
| 545 |
+
# output = output[1:-1]
|
| 546 |
|
| 547 |
+
# if len(output) > 2 and output[-1] == '.':
|
| 548 |
+
# output = output[:-1]
|
| 549 |
|
| 550 |
+
|
| 551 |
+
|
| 552 |
+
raw = raw_output.strip()
|
| 553 |
+
|
| 554 |
+
# Find the first valid "Response: ..." occurrence
|
| 555 |
+
match = re.search(r"Response:\s*([^\n\.]+)", raw)
|
| 556 |
+
|
| 557 |
+
if match:
|
| 558 |
+
output = match.group(1).strip()
|
| 559 |
+
else:
|
| 560 |
+
# fallback: take first line
|
| 561 |
+
output = raw.split("\n")[0].strip()
|
| 562 |
+
|
| 563 |
+
# Clean quotes / trailing punctuation
|
| 564 |
+
output = output.strip('"').strip()
|
| 565 |
+
if output.endswith("."):
|
| 566 |
output = output[:-1]
|
| 567 |
|
| 568 |
+
|
| 569 |
+
|
| 570 |
state["output"] = output
|
| 571 |
|
| 572 |
logger.info(f"State (Safety Agent): {state}")
|