# CAI-20B / app.py — tigres2526
# Update app.py to use pipeline API for Inferless compatibility (commit 93afb6b, verified)
from typing import List, Optional
from transformers import pipeline, AutoTokenizer
import inferless
from pydantic import BaseModel, Field
import re
@inferless.request
class RequestObjects(BaseModel):
    """Schema of an incoming inference request.

    The generation fields map onto the keyword arguments passed to the
    transformers pipeline call in InferlessPythonModel.infer().
    """
    prompt: str = Field(default="What is deep learning?")
    # System message; when empty/None, infer() falls back to the model's
    # built-in system template.
    system_prompt: Optional[str] = "You are a marketing strategy expert. Provide clear, actionable advice."
    temperature: Optional[float] = 0.7  # Lower for more focused responses
    top_p: Optional[float] = 0.9  # nucleus-sampling cutoff
    top_k: Optional[int] = 50  # sample only from the k most likely tokens
    length_penalty: Optional[float] = 1.0  # beam-search length penalty
    repetition_penalty: Optional[float] = 1.2  # Higher to reduce repetition
    # NOTE(review): declared as the string "None" (not None) and never
    # forwarded to the pipeline in infer() — confirm before relying on it.
    stop_strings: Optional[str] = "None"
    early_stopping: Optional[bool] = True  # Better for production
    max_new_tokens: Optional[int] = 250  # Increased for better responses
    min_new_tokens: Optional[int] = 50
    do_sample: Optional[bool] = True  # Enable for more natural responses
    num_beams: Optional[int] = 1  # 1 = greedy/sampling, no beam search
    # NOTE(review): min_length/max_length are accepted but not forwarded to
    # the pipeline in infer() — verify whether they should be.
    min_length: Optional[int] = 0
    max_length: Optional[int] = 512
@inferless.response
class ResponseObjects(BaseModel):
    """Schema of the inference response.

    Lists run in parallel: role[i]/context[i]/cleaned[i] describe one message.
    """
    role: List[str] = Field(default=["assistant"])  # speaker of each message
    context: List[str] = Field(default=["Response"])  # raw generated text
    cleaned: Optional[List[str]] = Field(default=None)  # artifact-stripped copy of context
class ResponseCleaner:
    """Lightweight cleanup for model responses.

    Strips chat-template artifacts and leaked chain-of-thought phrases,
    normalizes whitespace, removes stuttered duplicate words, and trims a
    dangling sentence fragment from the end of the text.
    """

    def __init__(self):
        # Common artifacts to remove, applied in order.
        self.artifact_patterns = [
            r'<\|[^>]+\|>',  # Special tokens
            r'\bassistantfinal\b',
            r'\bassistant\s+final\b',
            r'\bassistant\b(?!\s*:)',
            r'We need to understand[^.:\n]{0,50}[:.]?\s*',
            r'We need to[^.:\n]{0,50}[:.]?\s*',
            r'I need to[^.:\n]{0,50}[:.]?\s*',
            r'Let me[^.:\n]{0,50}[:.]?\s*',
            r'According to guidelines[^.:\n]{0,50}[:.]?\s*',
            r'The prompt asks[^.:\n]{0,50}[:.]?\s*',
            r'The user asks[^.:\n]{0,50}[:.]?\s*',
            r'Wait question[^.:\n]{0,50}[:.]?\s*',
            r'We must respond[^.:\n]{0,50}[:.]?\s*',
            r"Let's produce[^.:\n]{0,50}[:.]?\s*",
            r'The answer:[^.:\n]{0,50}[:.]?\s*',
            r'Now produce final answer\.?\s*',
            r'assistant\s*:\s*',
            r'\\n\\n\\n+',  # literal "\n" escape sequences leaked as text
            r'\\u[0-9a-fA-F]{4}',  # literal unicode escapes leaked as text
            r'^\s*\.+\s*',
            r'\s*\.{3,}$',
        ]
        # Compile once here instead of recompiling on every clean_response()
        # call; kept private so self.artifact_patterns stays a string list.
        self._compiled = [
            re.compile(p, re.IGNORECASE | re.MULTILINE)
            for p in self.artifact_patterns
        ]

    def clean_response(self, text: str) -> str:
        """Return *text* with artifacts removed; "" for falsy input."""
        if not text:
            return ""
        cleaned = text
        # Remove artifacts.
        for rx in self._compiled:
            cleaned = rx.sub('', cleaned)
        # Collapse runs of horizontal whitespace but keep newlines.
        # Bug fix: the original used r'\s+' here, which destroyed every
        # newline and made the blank-line collapse below dead code.
        cleaned = re.sub(r'[^\S\n]+', ' ', cleaned)
        cleaned = re.sub(r'\n\s*\n\s*\n', '\n\n', cleaned)
        # Remove immediately repeated words, looping to a fixpoint so runs
        # longer than two ("the the the") fully collapse; the original made
        # a single pass and could leave "the the" behind.
        while True:
            deduped = re.sub(r'\b(\w+)\s+\1\b', r'\1', cleaned)
            if deduped == cleaned:
                break
            cleaned = deduped
        cleaned = cleaned.strip()
        # Fix incomplete endings: drop a short (<20 char) fragment after the
        # last full sentence, or terminate the text if it has no period.
        if cleaned and cleaned[-1] not in '.!?':
            if len(cleaned.split('.')[-1].strip()) < 20:
                head, sep, _tail = cleaned.rpartition('.')
                cleaned = head + '.' if sep else cleaned + '.'
        return cleaned
class InferlessPythonModel:
    """Inferless entry point wrapping the CAI-20B text-generation pipeline."""

    def initialize(self):
        """Load the generation pipeline, ensure a usable tokenizer, and build the cleaner."""
        print("Loading CAI-20B Marketing Strategy Expert...")
        # Initialize the text generation pipeline.
        self.generator = pipeline(
            "text-generation",
            model="tigres2526/CAI-20B",
            device=0,  # Use GPU
            torch_dtype="auto",  # Let it choose appropriate dtype
            trust_remote_code=True,
        )
        # Ensure a tokenizer is loaded.
        if not self.generator.tokenizer:
            self.generator.tokenizer = AutoTokenizer.from_pretrained(
                "tigres2526/CAI-20B",
                trust_remote_code=True,
            )
        # Bug fix: the original only set pad_token on the fallback tokenizer,
        # so a pipeline-provided tokenizer without one left pad_token_id=None
        # at generation time. Patch it unconditionally.
        if not self.generator.tokenizer.pad_token:
            self.generator.tokenizer.pad_token = self.generator.tokenizer.eos_token
        # Initialize response cleaner.
        self.cleaner = ResponseCleaner()
        # Default system prompt used when the request does not supply one.
        self.system_template = """You are a marketing strategy assistant powered by CAI-20B.
Knowledge cutoff: 2024-06
Current date: 2025-08-07
CRITICAL INSTRUCTIONS:
- Provide ONLY the final answer without any internal reasoning
- NEVER include tokens like <|assistant|>, <|user|>, or similar
- NEVER explain your thought process
- Keep responses concise, professional, and actionable
- Focus on marketing strategy and business growth"""
        print("✅ Model ready for inference!")

    def _format_prompt(self, system_prompt: str, user_prompt: str) -> str:
        """Render the (system, user) pair via the chat template when the tokenizer has one."""
        tokenizer = self.generator.tokenizer
        if getattr(tokenizer, 'chat_template', None):
            return tokenizer.apply_chat_template(
                [
                    {"role": "system", "content": system_prompt},
                    {"role": "user", "content": user_prompt},
                ],
                tokenize=False,
                add_generation_prompt=True,
            )
        # Fallback to simple format.
        return f"{system_prompt}\n\nUser: {user_prompt}\nAssistant:"

    def infer(self, request: RequestObjects) -> ResponseObjects:
        """Generate a response for *request* and return raw + cleaned text.

        Returns a ResponseObjects whose parallel lists hold the message
        role(s), the raw generated text, and an artifact-stripped copy.
        """
        # Use the optimized system prompt when none is provided.
        system_prompt = request.system_prompt or self.system_template
        formatted_prompt = self._format_prompt(system_prompt, request.prompt)
        tokenizer = self.generator.tokenizer
        pipeline_output = self.generator(
            formatted_prompt,
            max_new_tokens=int(request.max_new_tokens),
            min_new_tokens=int(request.min_new_tokens),
            temperature=float(request.temperature),
            top_p=float(request.top_p),
            top_k=int(request.top_k),
            repetition_penalty=float(request.repetition_penalty),
            length_penalty=float(request.length_penalty),
            do_sample=bool(request.do_sample),
            num_beams=int(request.num_beams),
            early_stopping=bool(request.early_stopping),
            pad_token_id=tokenizer.pad_token_id,
            eos_token_id=tokenizer.eos_token_id,
            # NOTE(review): request.stop_strings is accepted but deliberately
            # not forwarded — the pipeline build in use does not support it.
        )
        generated_text = pipeline_output[0]["generated_text"]
        # Bug fix: the original ran clean_response() on generated_text before
        # checking whether it was a chat-style list, which raised TypeError
        # inside re.sub for list outputs. Branch first, then clean per item.
        if isinstance(generated_text, list):
            # Chat-style output: a list of {"role", "content"} messages.
            roles = [msg.get("role", "") for msg in generated_text]
            contexts = [msg.get("content", "") for msg in generated_text]
            return ResponseObjects(
                role=roles,
                context=contexts,
                cleaned=[self.cleaner.clean_response(ctx) for ctx in contexts],
            )
        # Plain text output: drop the echoed prompt, then clean.
        if formatted_prompt in generated_text:
            generated_text = generated_text.replace(formatted_prompt, "").strip()
        return ResponseObjects(
            role=["assistant"],
            context=[generated_text],
            cleaned=[self.cleaner.clean_response(generated_text)],
        )

    def finalize(self):
        """Drop references so the framework can reclaim model/GPU memory."""
        self.generator = None
        self.cleaner = None
# Local smoke test: load the model, run one request, print the best output.
if __name__ == "__main__":
    model = InferlessPythonModel()
    model.initialize()

    test_request = RequestObjects(
        prompt="What are the top 3 marketing channels for B2B SaaS?",
        temperature=0.7,
        max_new_tokens=200,
        do_sample=True,
    )
    response = model.infer(test_request)

    # Prefer the cleaned text when available, otherwise fall back to raw.
    best = response.cleaned[0] if response.cleaned else response.context[0]
    print(f"Response: {best}")
    model.finalize()