Sandiago21 committed on
Commit
3d6d0a5
·
verified ·
1 Parent(s): 2b5f879

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +65 -7
app.py CHANGED
@@ -42,6 +42,7 @@ class Config(object):
42
  self.max_len = 256
43
  self.reasoning_max_len = 128
44
  self.temperature = 0.1
 
45
  self.DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
46
  self.model_name = "Qwen/Qwen2.5-7B-Instruct"
47
  # self.model_name = "deepseek-ai/DeepSeek-R1-Distill-Qwen-7B"
@@ -49,6 +50,7 @@ class Config(object):
49
  # self.reasoning_model_name = "Qwen/Qwen2.5-7B-Instruct"
50
  # self.reasoning_model_name = "mistralai/Mistral-7B-Instruct-v0.2"
51
 
 
52
  config = Config()
53
 
54
 
@@ -89,6 +91,7 @@ def generate(prompt):
89
  **inputs,
90
  max_new_tokens=config.max_len,
91
  temperature=config.temperature,
 
92
  )
93
 
94
  generated = outputs[0][inputs["input_ids"].shape[-1]:]
@@ -118,12 +121,45 @@ def reasoning_generate(prompt):
118
  **inputs,
119
  max_new_tokens=config.reasoning_max_len,
120
  temperature=config.temperature,
 
121
  )
122
 
123
  generated = outputs[0][inputs["input_ids"].shape[-1]:]
124
 
125
  return reasoning_tokenizer.decode(generated, skip_special_tokens=True).strip()
126
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
127
  class Action(BaseModel):
128
  tool: str = Field(...)
129
  args: Dict
@@ -476,6 +512,8 @@ Response: <answer>
476
 
477
  DO NOT add anything additional and return ONLY what is asked and in the format asked.
478
 
 
 
479
  ONLY return a response if you are confident about the answer, otherwise return empty string.
480
 
481
  If you output anything else, it is incorrect.
@@ -498,17 +536,37 @@ Information:
498
 
499
  logger.info(f"Raw Output: {raw_output}")
500
 
501
- output = raw_output.split("Response:")[-1].strip()
502
- output = output.split("\n")[0].strip()
503
- # match = re.search(r"Response:\s*(.*)", raw_output, re.IGNORECASE)
504
- # output = match.group(1).strip() if match else ""
 
 
 
505
 
506
- if len(output) > 2 and output[0] == '"' and output[-1] == '"':
507
- output = output[1:-1]
508
 
509
- if len(output) > 2 and output[-1] == '.':
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
510
  output = output[:-1]
511
 
 
 
512
  state["output"] = output
513
 
514
  logger.info(f"State (Safety Agent): {state}")
 
42
  self.max_len = 256
43
  self.reasoning_max_len = 128
44
  self.temperature = 0.1
45
+ self.repetition_penalty = 1.2
46
  self.DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
47
  self.model_name = "Qwen/Qwen2.5-7B-Instruct"
48
  # self.model_name = "deepseek-ai/DeepSeek-R1-Distill-Qwen-7B"
 
50
  # self.reasoning_model_name = "Qwen/Qwen2.5-7B-Instruct"
51
  # self.reasoning_model_name = "mistralai/Mistral-7B-Instruct-v0.2"
52
 
53
+
54
  config = Config()
55
 
56
 
 
91
  **inputs,
92
  max_new_tokens=config.max_len,
93
  temperature=config.temperature,
94
+ repetition_penalty = config.repetition_penalty,
95
  )
96
 
97
  generated = outputs[0][inputs["input_ids"].shape[-1]:]
 
121
  **inputs,
122
  max_new_tokens=config.reasoning_max_len,
123
  temperature=config.temperature,
124
+ repetition_penalty = config.repetition_penalty,
125
  )
126
 
127
  generated = outputs[0][inputs["input_ids"].shape[-1]:]
128
 
129
  return reasoning_tokenizer.decode(generated, skip_special_tokens=True).strip()
130
 
131
+
132
+ def reasoning_generate(prompt):
133
+ """
134
+ Generate a text completion from a causal language model given a prompt.
135
+
136
+ Parameters
137
+ ----------
138
+ prompt : str
139
+ Input text prompt used to condition the language model.
140
+
141
+ Returns
142
+ -------
143
+ str
144
+ The generated continuation text, decoded into a string with special
145
+ tokens removed and leading/trailing whitespace stripped.
146
+
147
+ """
148
+ inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
149
+
150
+ with torch.no_grad():
151
+ outputs = model.generate(
152
+ **inputs,
153
+ max_new_tokens=config.reasoning_max_len,
154
+ temperature=config.temperature,
155
+ repetition_penalty = config.repetition_penalty,
156
+ )
157
+
158
+ generated = outputs[0][inputs["input_ids"].shape[-1]:]
159
+
160
+ return tokenizer.decode(generated, skip_special_tokens=True).strip()
161
+
162
+
163
  class Action(BaseModel):
164
  tool: str = Field(...)
165
  args: Dict
 
512
 
513
  DO NOT add anything additional and return ONLY what is asked and in the format asked.
514
 
515
+ If you output anything else, it is incorrect.
516
+
517
  ONLY return a response if you are confident about the answer, otherwise return empty string.
518
 
519
  If you output anything else, it is incorrect.
 
536
 
537
  logger.info(f"Raw Output: {raw_output}")
538
 
539
+ # output = raw_output.split("Response:")[-1].strip()
540
+ # output = output.split("\n")[0].strip()
541
+ # # match = re.search(r"Response:\s*(.*)", raw_output, re.IGNORECASE)
542
+ # # output = match.group(1).strip() if match else ""
543
+
544
+ # if len(output) > 2 and output[0] == '"' and output[-1] == '"':
545
+ # output = output[1:-1]
546
 
547
+ # if len(output) > 2 and output[-1] == '.':
548
+ # output = output[:-1]
549
 
550
+
551
+
552
+ raw = raw_output.strip()
553
+
554
+ # Find the first valid "Response: ..." occurrence
555
+ match = re.search(r"Response:\s*([^\n\.]+)", raw)
556
+
557
+ if match:
558
+ output = match.group(1).strip()
559
+ else:
560
+ # fallback: take first line
561
+ output = raw.split("\n")[0].strip()
562
+
563
+ # Clean quotes / trailing punctuation
564
+ output = output.strip('"').strip()
565
+ if output.endswith("."):
566
  output = output[:-1]
567
 
568
+
569
+
570
  state["output"] = output
571
 
572
  logger.info(f"State (Safety Agent): {state}")