# CAI-20B / app_inferless.py
# Uploaded by tigres2526 via huggingface_hub (commit 9b972ab, verified)
from typing import List, Optional
from transformers import pipeline, AutoTokenizer
import inferless
from pydantic import BaseModel, Field
import re
@inferless.request
class RequestObjects(BaseModel):
    """Validated inference request for the text-generation endpoint.

    All generation parameters are optional and default to values tuned for
    production marketing-advice responses (see inline notes).
    """

    prompt: str = Field(default="What is deep learning?")
    # When falsy, InferlessPythonModel.infer falls back to its built-in template.
    system_prompt: Optional[str] = "You are a marketing strategy expert. Provide clear, actionable advice."
    temperature: Optional[float] = 0.7  # Lower for more focused responses
    top_p: Optional[float] = 0.9
    top_k: Optional[int] = 50
    length_penalty: Optional[float] = 1.0
    repetition_penalty: Optional[float] = 1.2  # Higher to reduce repetition
    # NOTE(review): string sentinel "None" (not the None object) — the consumer
    # compares against the literal "None"; confirm before changing the default.
    stop_strings: Optional[str] = "None"
    early_stopping: Optional[bool] = True  # Better for production
    max_new_tokens: Optional[int] = 250  # Increased for better responses
    min_new_tokens: Optional[int] = 50
    do_sample: Optional[bool] = True  # Enable for more natural responses
    num_beams: Optional[int] = 1
    min_length: Optional[int] = 0
    max_length: Optional[int] = 512
@inferless.response
class ResponseObjects(BaseModel):
    """Inference response: parallel lists of roles and generated texts."""

    # Role per returned message (a single "assistant" entry for plain text).
    role: List[str] = Field(default=["assistant"])
    # Raw generated text per message.
    context: List[str] = Field(default=["Response"])
    # Artifact-stripped version of `context`; None until cleanup has run.
    cleaned: Optional[List[str]] = Field(default=None)  # Add cleaned version
class ResponseCleaner:
    """Lightweight cleanup for model responses.

    Strips chat-template artifacts and leaked reasoning phrases, normalizes
    whitespace, removes immediately repeated words, and trims a dangling
    trailing sentence fragment so responses end on a full sentence.
    """

    def __init__(self):
        # Common artifacts to remove. Compiled once here so clean_response()
        # does not re-parse all 22 patterns on every request.
        self.artifact_patterns = [
            re.compile(pattern, flags=re.IGNORECASE | re.MULTILINE)
            for pattern in [
                r'<\|[^>]+\|>',  # Special tokens, e.g. <|assistant|>
                r'\bassistantfinal\b',
                r'\bassistant\s+final\b',
                r'\bassistant\b(?!\s*:)',
                r'We need to understand[^.:\n]{0,50}[:.]?\s*',
                r'We need to[^.:\n]{0,50}[:.]?\s*',
                r'I need to[^.:\n]{0,50}[:.]?\s*',
                r'Let me[^.:\n]{0,50}[:.]?\s*',
                r'According to guidelines[^.:\n]{0,50}[:.]?\s*',
                r'The prompt asks[^.:\n]{0,50}[:.]?\s*',
                r'The user asks[^.:\n]{0,50}[:.]?\s*',
                r'Wait question[^.:\n]{0,50}[:.]?\s*',
                r'We must respond[^.:\n]{0,50}[:.]?\s*',
                r"Let's produce[^.:\n]{0,50}[:.]?\s*",
                r'The answer:[^.:\n]{0,50}[:.]?\s*',
                r'Now produce final answer\.?\s*',
                r'assistant\s*:\s*',
                r'\\n\\n\\n+',  # Literal backslash-n runs leaked as text
                r'\\u[0-9a-fA-F]{4}',  # Literal \uXXXX escapes leaked as text
                r'^\s*\.+\s*',
                r'\s*\.{3,}$',
            ]
        ]

    def clean_response(self, text: str) -> str:
        """Remove artifacts from *text* and return the cleaned string."""
        if not text:
            return ""
        cleaned = text
        # Remove artifacts
        for pattern in self.artifact_patterns:
            cleaned = pattern.sub('', cleaned)
        # Collapse every whitespace run (including newlines) to a single space.
        # BUGFIX: the original then re-collapsed triple newlines, but that was
        # dead code — no newline survives the substitution above.
        cleaned = re.sub(r'\s+', ' ', cleaned)
        # Remove immediately duplicated words ("the the" -> "the").
        cleaned = re.sub(r'\b(\w+)\s+\1\b', r'\1', cleaned)
        cleaned = cleaned.strip()
        # Fix incomplete endings: drop a short trailing fragment after the last
        # full sentence, or terminate a short unpunctuated text with a period.
        if cleaned and cleaned[-1] not in '.!?':
            if len(cleaned.split('.')[-1].strip()) < 20:
                parts = cleaned.rsplit('.', 1)
                if len(parts) > 1:
                    cleaned = parts[0] + '.'
                else:
                    cleaned += '.'
        return cleaned
class InferlessPythonModel:
    """Inferless entry point wrapping the CAI-20B text-generation pipeline."""

    def initialize(self):
        """Load the generation pipeline, tokenizer fallback, and cleaner."""
        print("Loading CAI-20B Marketing Strategy Expert...")
        # Initialize the text generation pipeline
        self.generator = pipeline(
            "text-generation",
            model="tigres2526/CAI-20B",  # Corrected model name
            device=0,  # Use GPU
            torch_dtype="auto",  # Let it choose appropriate dtype
            trust_remote_code=True,
        )
        # Some checkpoints ship without a tokenizer attached to the pipeline;
        # load one explicitly and make sure a pad token is defined.
        if not self.generator.tokenizer:
            tokenizer = AutoTokenizer.from_pretrained(
                "tigres2526/CAI-20B",
                trust_remote_code=True,
            )
            if not tokenizer.pad_token:
                tokenizer.pad_token = tokenizer.eos_token
            self.generator.tokenizer = tokenizer
        # Initialize response cleaner
        self.cleaner = ResponseCleaner()
        # Optimized prompt template for marketing expertise; used whenever a
        # request does not carry its own system prompt.
        self.system_template = """You are a marketing strategy assistant powered by CAI-20B.
Knowledge cutoff: 2024-06
Current date: 2025-08-07
CRITICAL INSTRUCTIONS:
- Provide ONLY the final answer without any internal reasoning
- NEVER include tokens like <|assistant|>, <|user|>, or similar
- NEVER explain your thought process
- Keep responses concise, professional, and actionable
- Focus on marketing strategy and business growth"""
        print("✅ Model ready for inference!")

    def infer(self, request: RequestObjects) -> ResponseObjects:
        """Generate a response for *request*, returning raw and cleaned text."""
        # Use optimized system prompt if not provided
        system_prompt = request.system_prompt or self.system_template
        tokenizer = self.generator.tokenizer
        # Prefer the model's own chat template when one is available.
        if hasattr(tokenizer, 'chat_template') and tokenizer.chat_template:
            messages = [
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": request.prompt},
            ]
            formatted_prompt = tokenizer.apply_chat_template(
                messages,
                tokenize=False,
                add_generation_prompt=True
            )
        else:
            # Fallback to simple format
            formatted_prompt = f"{system_prompt}\n\nUser: {request.prompt}\nAssistant:"
        # Generate with optimized parameters for production
        pipeline_output = self.generator(
            formatted_prompt,
            max_new_tokens=int(request.max_new_tokens),
            min_new_tokens=int(request.min_new_tokens),
            temperature=float(request.temperature),
            top_p=float(request.top_p),
            top_k=int(request.top_k),
            repetition_penalty=float(request.repetition_penalty),
            length_penalty=float(request.length_penalty),
            do_sample=bool(request.do_sample),
            num_beams=int(request.num_beams),
            early_stopping=bool(request.early_stopping),
            pad_token_id=tokenizer.pad_token_id,
            eos_token_id=tokenizer.eos_token_id,
            # Remove stop_strings if not supported
            # stop_strings=request.stop_strings if request.stop_strings != "None" else None,
        )
        generated_text = pipeline_output[0]["generated_text"]
        # BUGFIX: the original cleaned `generated_text` before checking whether
        # it was a list, which raised TypeError inside re.sub for chat-style
        # output. Branch first, then clean the appropriate payload.
        if isinstance(generated_text, list):
            # Chat-style output: a list of {"role", "content"} messages.
            roles = [item.get("role", "") for item in generated_text]
            contexts = [item.get("content", "") for item in generated_text]
            cleaned_contexts = [self.cleaner.clean_response(ctx) for ctx in contexts]
            return ResponseObjects(
                role=roles,
                context=contexts,
                cleaned=cleaned_contexts
            )
        # Plain text output: strip the echoed prompt, then clean artifacts.
        if formatted_prompt in generated_text:
            generated_text = generated_text.replace(formatted_prompt, "").strip()
        cleaned_text = self.cleaner.clean_response(generated_text)
        return ResponseObjects(
            role=["assistant"],
            context=[generated_text],
            cleaned=[cleaned_text]
        )

    def finalize(self):
        """Release model resources so the worker can shut down cleanly."""
        self.generator = None
        self.cleaner = None
# For local testing
if __name__ == "__main__":
    # Smoke-test the full request/response path on one sample prompt.
    model = InferlessPythonModel()
    model.initialize()

    sample = RequestObjects(
        prompt="What are the top 3 marketing channels for B2B SaaS?",
        temperature=0.7,
        max_new_tokens=200,
        do_sample=True,
    )
    result = model.infer(sample)
    best_text = result.cleaned[0] if result.cleaned else result.context[0]
    print(f"Response: {best_text}")
    model.finalize()