| |
| """ |
| Example inference script for CAI-20B Marketing Strategy Expert |
| """ |
|
|
| import torch |
| from transformers import AutoModelForCausalLM, AutoTokenizer |
| import re |
|
|
class ResponseCleaner:
    """Strip decoding artifacts from raw model output and tidy the ending.

    Removes chat-template special tokens, leaked role markers, and
    meta-commentary lines, collapses whitespace, and trims a short
    dangling sentence fragment at the end of the response.
    """

    def __init__(self):
        # Raw pattern strings are kept public (and unchanged) for backward
        # compatibility; they are compiled once here so clean_response()
        # does not re-parse every regex on each call.
        self.artifact_patterns = [
            r'<\|[^>]+\|>',            # special tokens like <|im_end|>
            r'assistantfinal',         # fused role marker leaking into text
            r'assistant\s*final',
            r'We need to.*?(?=\n|$)',  # meta-commentary lines up to EOL
            r'Let me.*?(?=\n|$)',
            r'I need to.*?(?=\n|$)',
            r'\\n\\n\\n+',             # runs of literal escaped newlines
        ]
        self._compiled_patterns = [
            re.compile(pattern, re.IGNORECASE)
            for pattern in self.artifact_patterns
        ]

    def clean_response(self, text: str) -> str:
        """Return *text* with artifacts removed and whitespace normalized.

        If the cleaned text does not end in sentence punctuation and the
        trailing fragment after the last '.' is short (< 20 chars), the
        fragment is dropped (or a '.' appended when there is no prior
        sentence to fall back to).
        """
        cleaned = text

        # Remove every artifact pattern (case-insensitive).
        for pattern in self._compiled_patterns:
            cleaned = pattern.sub('', cleaned)

        # Collapse all whitespace runs into single spaces.
        cleaned = re.sub(r'\s+', ' ', cleaned).strip()

        # Tidy a response that was cut off mid-sentence.
        if cleaned and cleaned[-1] not in '.!?':
            last_sentence = cleaned.split('.')[-1].strip()
            if len(last_sentence) < 20:
                parts = cleaned.rsplit('.', 1)
                if len(parts) > 1:
                    # Drop the short dangling fragment after the last '.'.
                    cleaned = parts[0] + '.'
                else:
                    # No complete sentence to keep; just close this one.
                    cleaned += '.'

        return cleaned
|
|
|
|
class CAI20BMarketing:
    """CAI-20B Marketing Strategy Expert Model.

    Thin wrapper around a Hugging Face causal LM: builds a marketing-expert
    prompt, generates a completion, and (optionally) cleans decoding
    artifacts from the output via ResponseCleaner.
    """

    def __init__(self, model_name="tigres2526/CAI-20B", device="cuda"):
        """Load the tokenizer and model.

        Args:
            model_name: Hugging Face hub id or local path of the model.
            device: Kept for API compatibility; actual placement is
                delegated to ``device_map="auto"`` below.
        """
        print("Loading CAI-20B Marketing Strategy Expert...")

        self.device = device
        self.cleaner = ResponseCleaner()

        self.tokenizer = AutoTokenizer.from_pretrained(
            model_name,
            trust_remote_code=True
        )
        # Only fall back to EOS when no pad token exists, so a tokenizer
        # that ships its own pad token is not clobbered.
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token

        self.model = AutoModelForCausalLM.from_pretrained(
            model_name,
            device_map="auto",
            torch_dtype=torch.bfloat16,
            trust_remote_code=True
        )
        self.model.eval()

        print("✅ Model loaded successfully!")

    def generate(
        self,
        question,
        max_new_tokens=250,
        temperature=0.7,
        top_p=0.9,
        repetition_penalty=1.1,
        clean_output=True
    ):
        """Generate marketing advice for a given question.

        Args:
            question: The user's marketing question.
            max_new_tokens: Maximum number of tokens to generate.
            temperature: Sampling temperature (higher = more random).
            top_p: Nucleus-sampling probability mass.
            repetition_penalty: Penalty > 1.0 discourages repeated tokens.
            clean_output: When True, post-process the response through
                ResponseCleaner to strip artifacts.

        Returns:
            The generated (and optionally cleaned) response string.
        """

        # System-style preamble plus the user turn; the model continues
        # after "Assistant:".
        prompt = f"""You are a marketing strategy expert specializing in performance marketing, creative development, and conversion optimization.
Provide practical, actionable advice grounded in real-world experience.

User: {question}
Assistant:"""

        # Truncate long prompts to the model's usable context window.
        inputs = self.tokenizer(
            prompt,
            return_tensors="pt",
            truncation=True,
            max_length=2048
        ).to(self.model.device)

        # Inference only — no gradients needed.
        with torch.no_grad():
            outputs = self.model.generate(
                **inputs,
                max_new_tokens=max_new_tokens,
                temperature=temperature,
                top_p=top_p,
                repetition_penalty=repetition_penalty,
                do_sample=True,
                pad_token_id=self.tokenizer.pad_token_id,
                eos_token_id=self.tokenizer.eos_token_id,
                no_repeat_ngram_size=3,
            )

        # Decode only the newly generated tokens (skip the prompt echo).
        response = self.tokenizer.decode(
            outputs[0][inputs['input_ids'].shape[1]:],
            skip_special_tokens=True
        )

        if clean_output:
            response = self.cleaner.clean_response(response)

        return response

    def chat(self):
        """Interactive chat mode: read questions from stdin until 'exit'."""
        print("\n" + "=" * 70)
        print("CAI-20B Marketing Strategy Expert - Interactive Chat")
        print("Type 'exit' to quit")
        print("=" * 70 + "\n")

        while True:
            user_input = input("You: ").strip()

            if user_input.lower() == 'exit':
                print("Goodbye!")
                break

            # Ignore empty lines instead of wasting a generation call.
            if not user_input:
                continue

            response = self.generate(user_input)

            print(f"\nCAI-20B: {response}\n")
            print("-" * 70 + "\n")
|
|
|
|
def main():
    """Example usage: answer a few canned questions, then offer chat mode."""
    expert = CAI20BMarketing()

    sample_questions = (
        "What are the top 3 marketing channels for a B2B SaaS startup?",
        "How should I allocate a $10K monthly marketing budget?",
        "What's the difference between CAC and LTV?",
    )

    banner = "=" * 70
    print(f"\n{banner}")
    print("Running example questions...")
    print(f"{banner}\n")

    # Walk the demo questions, printing each Q/A pair with a separator.
    for idx, question in enumerate(sample_questions, 1):
        print(f"Q{idx}: {question}")
        answer = expert.generate(question)
        print(f"A: {answer}\n")
        print("-" * 50 + "\n")

    # Hand off to the interactive loop if the user opts in.
    print("\nWould you like to start interactive chat? (y/n)")
    if input().lower() == 'y':
        expert.chat()
|
|
|
|
# Run the demo only when executed as a script, not on import.
if __name__ == "__main__":
    main()