tigres2526 committed on
Commit
ff065a4
·
verified ·
1 Parent(s): f621767

Add production utilities for artifact cleanup

Browse files
Files changed (1) hide show
  1. cai_20b_utils.py +368 -0
cai_20b_utils.py ADDED
@@ -0,0 +1,368 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ CAI-20B Utils - Production utilities for the CAI-20B Marketing Strategy Expert model
4
+ """
5
+
6
+ import re
7
+ import torch
8
+ from typing import Optional, Dict, Any
9
+ from transformers import AutoModelForCausalLM, AutoTokenizer
10
+
11
+
12
class ResponseCleaner:
    """Clean up model responses to remove artifacts and formatting issues.

    The cleaner is a small regex pipeline: strip known artifacts, collapse
    repeated segments, tidy the ending, then apply a minimum-quality gate.
    Paragraph breaks (newlines) in the response are preserved.
    """

    def __init__(self):
        # Common artifacts to remove. Patterns are applied in order with
        # re.IGNORECASE | re.MULTILINE. Phrase patterns start with \b so they
        # cannot fire in the middle of a word (e.g. "outLET MEdia"), and they
        # delete everything from the phrase to the end of the line.
        self.artifact_patterns = [
            r'<\|[^>]+\|>',  # Special tokens like <|assistant|>
            r'assistantfinal',
            r'assistant\s*final',
            r'\bassistant\b(?![\w\s]*:)',
            r'\bWe need to understand:.*?(?=\n|$)',
            r'\bWe need to.*?(?=\n|$)',
            r'\bI need to.*?(?=\n|$)',
            r'\bLet me.*?(?=\n|$)',
            r'\bAccording to guidelines.*?(?=\n|$)',
            r'\bThe prompt asks.*?(?=\n|$)',
            r'\bThe user asks.*?(?=\n|$)',
            r'\bWait question.*?(?=\n|$)',
            r'\bWe must respond.*?(?=\n|$)',
            r"\bLet's produce.*?(?=\n|$)",
            r'\bThe answer:.*?(?=\n|$)',
            r'\bThe conversation ends.*?(?=\n|$)',
            r'(?:\\n){3,}',  # Three or more literal "\n" escape sequences
            r'\\u[0-9a-fA-F]{4}',  # Literal "\uXXXX" escape sequences
        ]

        # Pattern for detecting repetition: a segment of 10+ characters that
        # occurs at least three times back to back.
        self.repetition_pattern = r'(.{10,}?)\1{2,}'

        # Patterns for incomplete endings
        self.incomplete_patterns = [
            r'\.{3,}$',  # Trailing ellipsis
            r'\s+\.\s*$',  # Trailing period with spaces
            r'\s+$',  # Trailing spaces
            r'^\s+',  # Leading spaces
        ]

    def clean_response(self, text: str) -> str:
        """Run the full cleaning pipeline on ``text``.

        Returns the cleaned text when it passes the quality gate. When it
        does not, the cleaned (artifact-free) text is still preferred over
        the raw input; the raw input is returned only if cleaning left
        nothing at all. An empty or falsy input yields ``""``.
        """
        if not text:
            return ""

        # Step 1: Remove artifacts
        cleaned = self.clean_artifacts(text)

        # Step 2: Fix repetitions
        cleaned = self.fix_repetitions(cleaned)

        # Step 3: Fix incomplete endings
        cleaned = self.fix_incomplete_endings(cleaned)

        # Step 4: Ensure minimum quality
        approved = self.ensure_minimum_quality(cleaned)
        if approved:
            return approved

        # Failing the gate must not re-introduce the artifacts just removed.
        return cleaned if cleaned else text

    def clean_artifacts(self, text: str) -> str:
        """Remove known artifacts and normalise whitespace.

        Runs of spaces/tabs become a single space and three or more
        consecutive newlines become one blank line; single and double
        newlines are kept so paragraphs and lists survive.
        """
        cleaned = text

        for pattern in self.artifact_patterns:
            cleaned = re.sub(pattern, '', cleaned, flags=re.IGNORECASE | re.MULTILINE)

        # Collapse horizontal whitespace only; newlines carry structure.
        cleaned = re.sub(r'[^\S\n]+', ' ', cleaned)
        # Drop the (at most one) space left on either side of a newline.
        cleaned = re.sub(r' ?\n ?', '\n', cleaned)
        cleaned = re.sub(r'\n{3,}', '\n\n', cleaned)

        return cleaned.strip()

    def fix_repetitions(self, text: str) -> str:
        """Collapse repeated segments and immediately doubled words.

        A segment matched by ``repetition_pattern`` is reduced to a single
        occurrence; DOTALL lets a repeated segment span line breaks.
        """
        cleaned = re.sub(
            self.repetition_pattern,
            lambda match: match.group(1),
            text,
            flags=re.DOTALL,
        )

        # Remove duplicate words ("the the" -> "the")
        cleaned = re.sub(r'\b(\w+)\s+\1\b', r'\1', cleaned)

        return cleaned

    def fix_incomplete_endings(self, text: str) -> str:
        """Tidy the end of ``text`` so it finishes with '.', '!' or '?'.

        A short trailing fragment (under 20 characters) after the last
        sentence boundary is dropped; otherwise a period is appended. A
        sentence boundary is '.', '!' or '?' followed by whitespace, so
        decimals such as "3.5x" are not mistaken for sentence ends.
        """
        # Strip first so the end-anchored patterns see the real last character.
        cleaned = text.strip()

        # Remove incomplete patterns
        for pattern in self.incomplete_patterns:
            cleaned = re.sub(pattern, '', cleaned)

        # Ensure proper ending punctuation
        if cleaned and cleaned[-1] not in '.!?':
            last_boundary = None
            for last_boundary in re.finditer(r'[.!?](?=\s)', cleaned):
                pass

            if last_boundary is not None:
                fragment = cleaned[last_boundary.end():].strip()
            else:
                fragment = None

            if fragment is not None and len(fragment) < 20:
                # Cut the dangling fragment, keep the last full sentence.
                cleaned = cleaned[:last_boundary.end()]
            else:
                cleaned += '.'

        return cleaned

    def ensure_minimum_quality(self, text: str, min_length: int = 50) -> Optional[str]:
        """Return ``text`` if it passes basic quality checks, else ``None``.

        Checks: stripped length of at least ``min_length``; at most 30% of
        characters outside letters, digits, whitespace and common
        punctuation; at least one sentence longer than 10 characters.
        """
        if len(text.strip()) < min_length:
            return None

        # Check for too many special characters
        special_char_ratio = len(re.findall(r'[^a-zA-Z0-9\s.,!?;:\'\"-]', text)) / max(len(text), 1)
        if special_char_ratio > 0.3:
            return None

        # Check for coherent sentences
        sentences = re.split(r'[.!?]+', text)
        if not any(len(sentence.strip()) > 10 for sentence in sentences):
            return None

        return text
130
+
131
+
132
class StrictPromptTemplate:
    """Strict prompt templates to prevent artifacts.

    Holds the fixed system and developer instructions and assembles them,
    together with the user's message, into a single prompt string.
    """

    # System-level instructions, one entry per prompt line.
    SYSTEM_PROMPT = "\n".join((
        "You are a marketing strategy assistant powered by gpt-oss.",
        "Knowledge cutoff: 2024-06",
        "Current date: 2025-08-06",
        "",
        "CRITICAL INSTRUCTIONS:",
        "- Provide ONLY the final answer without any internal reasoning",
        "- NEVER include tokens like <|assistant|>, <|user|>, or similar",
        "- NEVER explain your thought process or what you're doing",
        '- NEVER use phrases like "We need to", "Let me", "I need to"',
        "- NEVER repeat words or phrases",
        "- Always end responses properly with punctuation",
        "- Keep responses concise and professional",
    ))

    # Developer-level response requirements, one entry per prompt line.
    DEVELOPER_PROMPT = "\n".join((
        "# Response Requirements",
        "- Output ONLY the final response to the user",
        "- NO internal dialogue or reasoning exposition",
        "- NO meta-commentary about the task",
        "- NO repetitive text or loops",
        "- Must be complete, coherent sentences",
        "- Professional marketing expertise only",
        "- If uncertain, provide best practice guidance",
        "- Format: Direct, actionable advice",
    ))

    @classmethod
    def format_prompt(cls, user_message: str) -> str:
        """Build the full prompt for ``user_message``.

        The result is the system prompt, the developer prompt and a
        ``User: ... / Assistant:`` turn, each separated by one blank line.
        """
        user_turn = f"User: {user_message}\nAssistant:"
        sections = (cls.SYSTEM_PROMPT, cls.DEVELOPER_PROMPT, user_turn)
        return "\n\n".join(sections)
167
+
168
+
169
class CAI20B:
    """Production-ready wrapper for CAI-20B Marketing Strategy Expert.

    Loads the tokenizer and model once, then offers ``generate`` for single
    questions (with optional cleaning and retry on artifacts) and ``chat``
    for an interactive console loop.
    """

    # Returned by generate() when no usable text was produced.
    _FALLBACK_RESPONSE = (
        "I can help with marketing strategy questions. "
        "Please try rephrasing your question."
    )

    def __init__(
        self,
        model_name: str = "tigres2526/CAI-20B",
        device: str = "auto",
        torch_dtype = torch.bfloat16,
        trust_remote_code: bool = True
    ):
        """Initialize the model with production settings.

        Args:
            model_name: Model identifier passed to ``from_pretrained``.
            device: Passed as ``device_map`` when loading the model.
            torch_dtype: Dtype passed to ``from_pretrained``.
            trust_remote_code: Forwarded to both tokenizer and model loading.
        """
        print("Loading CAI-20B Marketing Strategy Expert...")

        self.device = device
        self.cleaner = ResponseCleaner()
        self.prompt_template = StrictPromptTemplate()

        # Load tokenizer
        self.tokenizer = AutoTokenizer.from_pretrained(
            model_name,
            trust_remote_code=trust_remote_code
        )
        # Generation needs a pad token; reuse EOS when none is defined.
        if not self.tokenizer.pad_token:
            self.tokenizer.pad_token = self.tokenizer.eos_token

        # Load model
        self.model = AutoModelForCausalLM.from_pretrained(
            model_name,
            device_map=device,
            torch_dtype=torch_dtype,
            trust_remote_code=trust_remote_code
        )
        self.model.eval()

        print("✅ Model ready for production use!")

    def generate(
        self,
        user_message: str,
        max_new_tokens: int = 250,
        temperature: float = 0.7,
        top_p: float = 0.9,
        repetition_penalty: float = 1.1,
        no_repeat_ngram_size: int = 3,
        do_sample: bool = True,
        clean_output: bool = True,
        retry_on_artifacts: bool = True,
        max_retries: int = 2
    ) -> str:
        """Generate a clean response to ``user_message``.

        Args:
            user_message: The user's question.
            max_new_tokens, temperature, top_p, repetition_penalty,
            no_repeat_ngram_size, do_sample: Generation settings forwarded
                to ``_generate_raw``.
            clean_output: Run the response through ``ResponseCleaner``.
            retry_on_artifacts: Regenerate when ``_has_artifacts`` flags the
                response.
            max_retries: Total number of generation attempts; values below 1
                are treated as 1.

        Returns:
            The last generated response, or a fixed fallback message when
            that response is empty.
        """
        # Format prompt with strict template
        prompt = self.prompt_template.format_prompt(user_message)

        # Always attempt at least once so ``response`` is defined.
        attempts = max(1, max_retries)
        response = ""

        for attempt in range(attempts):
            # Each retry samples more conservatively and penalises repeats harder.
            if attempt > 0:
                temperature = max(0.5, temperature - 0.1)
                repetition_penalty = min(1.5, repetition_penalty + 0.1)

            # Generate response
            response = self._generate_raw(
                prompt,
                max_new_tokens=max_new_tokens,
                temperature=temperature,
                top_p=top_p,
                repetition_penalty=repetition_penalty,
                no_repeat_ngram_size=no_repeat_ngram_size,
                do_sample=do_sample
            )

            # Clean if requested
            if clean_output:
                response = self.cleaner.clean_response(response)

            # Accept the response unless a retry is both wanted and needed.
            if not (retry_on_artifacts and self._has_artifacts(response)):
                break

            if attempt < attempts - 1:
                print(f"⚠️ Artifacts detected, retrying... (attempt {attempt + 2}/{attempts})")

        # Final fallback
        return response if response else self._FALLBACK_RESPONSE

    def _generate_raw(
        self,
        prompt: str,
        max_new_tokens: int,
        temperature: float,
        top_p: float,
        repetition_penalty: float,
        no_repeat_ngram_size: int,
        do_sample: bool
    ) -> str:
        """Tokenize ``prompt``, run generation and decode only the new tokens."""
        inputs = self.tokenizer(
            prompt,
            return_tensors="pt",
            truncation=True,
            max_length=2048
        )

        # Put inputs where the model lives. This also covers device="auto",
        # where the original left them on CPU.
        # NOTE(review): relies on the loaded model exposing ``.device`` as
        # transformers' PreTrainedModel does - confirm for remote-code models.
        inputs = inputs.to(self.model.device)

        generation_kwargs = {
            "max_new_tokens": max_new_tokens,
            "repetition_penalty": repetition_penalty,
            "no_repeat_ngram_size": no_repeat_ngram_size,
            "do_sample": do_sample,
            "pad_token_id": self.tokenizer.pad_token_id,
            "eos_token_id": self.tokenizer.eos_token_id,
        }
        # Sampling knobs only apply when sampling. ``early_stopping`` was
        # dropped because it is a beam-search option and no beams are used.
        if do_sample:
            generation_kwargs["temperature"] = temperature
            generation_kwargs["top_p"] = top_p

        with torch.no_grad():
            outputs = self.model.generate(**inputs, **generation_kwargs)

        # The output sequence starts with the prompt tokens; skip them.
        prompt_length = inputs['input_ids'].shape[1]
        return self.tokenizer.decode(
            outputs[0][prompt_length:],
            skip_special_tokens=True
        )

    def _has_artifacts(self, text: str) -> bool:
        """Return True when ``text`` is too short or contains an artifact marker.

        Text that is empty or shorter than 50 characters after stripping is
        treated as an artifact. Otherwise the check is a case-insensitive
        substring search for the indicators below.
        """
        if not text or len(text.strip()) < 50:
            return True

        artifact_indicators = (
            "we need to", "let me", "<|", "|>",
            # Two spaces: a single space would flag every multi-word answer.
            "assistant", "...", "  ", "according to guidelines",
            "the prompt asks", "wait question",
        )

        text_lower = text.lower()
        return any(indicator in text_lower for indicator in artifact_indicators)

    def chat(self):
        """Interactive chat mode.

        Reads lines from stdin until the user types 'exit' or input ends
        (EOF / Ctrl-C). Each message is answered independently; no
        conversation history is kept, so 'clear' only prints a notice.
        """
        print("\n" + "=" * 70)
        print("CAI-20B Marketing Strategy Expert - Interactive Chat")
        print("Type 'exit' to quit, 'clear' to reset conversation")
        print("=" * 70 + "\n")

        while True:
            try:
                user_input = input("You: ").strip()
            except (EOFError, KeyboardInterrupt):
                # Closed stdin or Ctrl-C ends the session cleanly.
                print("\nGoodbye!")
                break

            command = user_input.lower()

            if command == 'exit':
                print("Goodbye!")
                break

            if command == 'clear':
                print("Conversation cleared.\n")
                continue

            if not user_input:
                continue

            response = self.generate(user_input)
            print(f"\nCAI-20B: {response}\n")
            print("-" * 70 + "\n")
340
+
341
+
342
+ # Convenience function for quick usage
343
def quick_generate(question: str, model_name: str = "tigres2526/CAI-20B") -> str:
    """Answer ``question`` once with a freshly loaded model.

    A new ``CAI20B`` wrapper is created for ``model_name`` on every call and
    is not kept after the answer is returned, so each call pays the full
    model loading cost.
    """
    one_shot_model = CAI20B(model_name)
    answer = one_shot_model.generate(question)
    return answer
347
+
348
+
349
if __name__ == "__main__":
    # Smoke test: load the default model and answer a few sample questions.
    print("Testing CAI-20B Marketing Strategy Expert...")

    # Initialize model
    model = CAI20B()

    # Test questions
    test_questions = [
        "What are the top 3 marketing channels for a B2B SaaS startup?",
        "How should I allocate a $10K monthly marketing budget?",
        "What's the difference between CAC and LTV?",
    ]

    # Divider printed after every answer.
    divider = "-" * 50 + "\n"

    print("\nRunning test questions:\n")
    for question in test_questions:
        print(f"Q: {question}")
        response = model.generate(question)
        print(f"A: {response}\n")
        print(divider)