# suno / app.py
# Stanley03's picture
# Update app.py
# 4f06181 verified
"""
STANLEY AI - Optimized Flask Backend
Deploy on Hugging Face Spaces with fast, smaller models
"""
from flask import Flask, request, jsonify, send_file
from flask_cors import CORS
from transformers import AutoModelForCausalLM, AutoTokenizer, TextStreamer
import torch
import time
import re
import logging
from threading import Thread
import queue
import io
import base64
import random
from PIL import Image, ImageDraw, ImageFont
import os
import gc
# Configure logging
# Root logger is set to INFO; `logger` is this module's named logger and is
# used by every function below.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# Flask application object; the @app.route decorators below register on it.
app = Flask(__name__)
# Enable flask-cors on the whole app using the extension's default settings.
CORS(app)
# ============================================================================
# MODEL CONFIGURATION - OPTIMIZED FOR SPEED
# ============================================================================
# Hugging Face Hub model IDs, keyed by the short names that the
# /api/switch-model endpoint accepts.
MODEL_CONFIG = {
    # NOTE: Qwen2.5 instruct checkpoints come in 0.5B / 1.5B / 3B / ... sizes;
    # there is no 1.8B variant, so the 1.5B model is the one to request here.
    "primary": "Qwen/Qwen2.5-1.5B-Instruct",  # Fast, multilingual, good balance
    "fallback": "microsoft/Phi-3-mini-4k-instruct",  # Ultra-fast alternative
    "tiny": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",  # For minimal memory usage
}

# Module-level model state, written by load_model_optimized().
model = None
tokenizer = None
model_loaded = False
current_model_name = None

# Performance cache: normalized user message -> generated response.
response_cache = {}
# Maximum number of entries kept in response_cache.
CACHE_SIZE = 200
# System Prompt (optimized for speed)
# Sent as the "system" message of every chat request built in
# generate_response(). The text is runtime data: editing it changes model
# behaviour.
STANLEY_AI_SYSTEM = """You are STANLEY AI - an advanced assistant with Kiswahili cultural knowledge.
Provide helpful, concise responses. Integrate Kiswahili phrases naturally when relevant.
Key capabilities:
- Answer questions knowledgeably
- Use Kiswahili for greetings, proverbs, and cultural references
- Explain concepts clearly
- Be efficient and to the point
Format: Use **bold** for emphasis. Keep responses under 300 words unless detailed explanation is needed."""
# Simple Kiswahili knowledge base (replaces external file)
# enhance_with_kiswahili() reads the "greetings" and "proverbs" sections;
# "lion_king" is not referenced by the code visible in this file.
KISWAHILI_KNOWLEDGE = {
    # English intent -> Kiswahili phrase(s)
    "greetings": {
        "hello": "Jambo / Habari",
        "how_are_you": "Habari yako?",
        "goodbye": "Kwaheri / Tuonane tena",
        "thank_you": "Asante sana",
        "welcome": "Karibu / Karibuni"
    },
    # Each entry is "<Kiswahili proverb> - <English translation>"
    "proverbs": [
        "Mwenye pupa hadiriki kula tamu - The impatient one misses sweet things.",
        "Asiyefunzwa na mamae hufunzwa na ulimwengu - He who is not taught by his mother is taught by the world.",
        "Haraka haraka haina baraka - Hurry hurry has no blessing.",
        "Ukitaka kwenda haraka, nenda peke yako. Ukitaka kwenda mbali, nenda na wenzako - If you want to go fast, go alone. If you want to go far, go together."
    ],
    # Lion King name/phrase -> short English gloss
    "lion_king": {
        "simba": "Lion (the main character)",
        "rafiki": "Friend (the wise baboon)",
        "hakuna_matata": "No worries / No problems",
        "mufasa": "Simba's father, the king",
        "nala": "Simba's childhood friend and queen"
    }
}
def load_model_optimized(model_name=None):
    """Load a causal LM and its tokenizer into the module-level globals.

    On failure the function retries once with ``MODEL_CONFIG["fallback"]``
    (unless the fallback itself was the model that failed).

    Args:
        model_name: Hugging Face model ID to load. ``MODEL_CONFIG["primary"]``
            is used when omitted or empty.

    Returns:
        bool: True when a model is ready (freshly loaded, loaded via the
        fallback, or already loaded under the requested name); False when
        neither the requested model nor the fallback could be loaded.
    """
    global model, tokenizer, model_loaded, current_model_name
    # Resolve the default first so the "already loaded" check compares real IDs.
    if not model_name:
        model_name = MODEL_CONFIG["primary"]
    if model_loaded and model_name == current_model_name:
        # Report success explicitly: callers such as switch_model() test the
        # return value, and a bare `return` (None) reads as a failure.
        return True
    logger.info(f"🚀 Loading model: {model_name}")
    try:
        # Clear previous model from memory.
        # Rebind to None instead of `del`: deleting a global removes the name
        # itself, so any later access (including the fallback attempt below)
        # would raise NameError.
        if model is not None:
            model = None
            tokenizer = None
            # Nothing is loaded from this point until the new load succeeds.
            model_loaded = False
            current_model_name = None
            gc.collect()
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
        # Load tokenizer
        tokenizer = AutoTokenizer.from_pretrained(
            model_name,
            trust_remote_code=True,
            use_fast=True,  # Fast tokenizer for speed
        )
        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token
        # Load model with 4-bit quantization for speed and memory efficiency.
        # NOTE(review): load_in_4bit presumably needs the bitsandbytes package
        # in the deployment image - confirm, otherwise every load falls
        # through to the except branch.
        model = AutoModelForCausalLM.from_pretrained(
            model_name,
            torch_dtype=torch.float16,
            device_map="auto",
            load_in_4bit=True,  # 4-bit quantization for speed
            low_cpu_mem_usage=True,
            trust_remote_code=True,
        )
        model.eval()  # Set to evaluation mode
        model_loaded = True
        current_model_name = model_name
        # Pre-warm model with a simple prompt
        prewarm_model()
        logger.info(f"✅ Model loaded successfully: {model_name}")
        logger.info(f"📊 Model device: {model.device}")
        return True
    except Exception as e:
        logger.error(f"❌ Error loading model: {e}")
        # Do not leave a half-initialised tokenizer/model pair behind.
        model = None
        tokenizer = None
        model_loaded = False
        current_model_name = None
        # Try fallback
        if model_name != MODEL_CONFIG["fallback"]:
            logger.info("🔄 Trying fallback model...")
            return load_model_optimized(MODEL_CONFIG["fallback"])
        logger.error("❌ All models failed to load")
        return False
def prewarm_model():
    """Run one tiny greedy generation to warm up the freshly loaded model.

    Uses the module-level ``tokenizer`` and ``model``. Any exception is
    logged as a warning and otherwise ignored, so a failed warm-up never
    aborts model loading.
    """
    warmup_chat = [
        {"role": "system", "content": "Say hello briefly."},
        {"role": "user", "content": "Hello, STANLEY AI!"},
    ]
    try:
        prompt_text = tokenizer.apply_chat_template(
            warmup_chat,
            tokenize=False,
            add_generation_prompt=True,
        )
        encoded = tokenizer(prompt_text, return_tensors="pt").to(model.device)
        # The output is discarded; only the side effect of running matters.
        with torch.no_grad():
            model.generate(**encoded, max_new_tokens=10, do_sample=False)
        logger.info("✅ Model pre-warmed successfully!")
    except Exception as e:
        logger.warning(f"Pre-warm failed: {e}")
# Keywords that mark a text as Kiswahili / East-African / Lion King related.
_KISWAHILI_KEYWORDS = (
    'swahili', 'kiswahili', 'hakuna', 'matata', 'asante', 'rafiki',
    'jambo', 'mambo', 'pole', 'sawa', 'karibu', 'kwaheri', 'simba',
    'lion king', 'mufasa', 'nala', 'kenya', 'tanzania', 'africa',
    'habari', 'nze', 'pumbaa', 'timon', 'safari', 'ujamaa',
)
# Each keyword must begin at a word boundary. A plain substring test flags
# unrelated words ("bronze" contains "nze", "Napoleon" contains "pole"), while
# anchoring only the start still accepts derived forms such as "African" or
# "Kenyan".
_KISWAHILI_PATTERN = re.compile(
    r"\b(?:" + "|".join(re.escape(word) for word in _KISWAHILI_KEYWORDS) + r")"
)


def detect_kiswahili_context(text):
    """Detect if text contains Kiswahili or cultural references.

    Args:
        text: The text to inspect. Empty values and non-string values are
            treated as having no Kiswahili context.

    Returns:
        bool: True when any keyword occurs at the start of a word in *text*
        (case-insensitive), otherwise False.
    """
    if not text or not isinstance(text, str):
        return False
    return _KISWAHILI_PATTERN.search(text.lower()) is not None
def enhance_with_kiswahili(response, user_message):
    """Add Kiswahili elements to a response when the user message calls for it.

    Args:
        response: The generated answer text.
        user_message: The user's original message; it decides whether and how
            the answer is decorated.

    Returns:
        str: *response* unchanged when ``detect_kiswahili_context`` finds no
        Kiswahili context in *user_message*. Otherwise the response prefixed
        with a Kiswahili opener, followed by a random proverb when the message
        asks for advice/wisdom and by a Lion King fact when it mentions the
        film's keywords.
    """
    if not detect_kiswahili_context(user_message):
        return response
    message_lower = user_message.lower()
    greetings = KISWAHILI_KNOWLEDGE["greetings"]
    # Only phrases that work as an opener are eligible. The greetings table
    # also holds "goodbye" and "thank_you", which read wrongly when placed in
    # front of an answer.
    opener = random.choice([greetings["hello"], greetings["welcome"]])
    enhanced = f"{opener}! {response}"
    # Add a proverb if appropriate
    if any(word in message_lower for word in ('advice', 'wisdom', 'lesson', 'teach')):
        proverb = random.choice(KISWAHILI_KNOWLEDGE["proverbs"])
        enhanced += f"\n\n**🔥 Kiswahili Proverb:** {proverb}"
    # Add Lion King reference if relevant
    if any(word in message_lower for word in ('lion', 'simba', 'mufasa', 'disney')):
        lion_fact = "Did you know? 'Simba' means lion in Kiswahili, and 'Rafiki' means friend!"
        enhanced += f"\n\n{lion_fact}"
    return enhanced
def _cache_key(user_message):
    """Return the cache key for a message: lower-cased and stripped.

    The whole message is used. Truncating the key would make different long
    prompts that share a prefix return each other's cached answers.
    """
    return user_message.lower().strip()


def get_cached_response(user_message):
    """Get response from cache.

    Args:
        user_message: The user's message text.

    Returns:
        The cached response for the normalized message, or None when absent.
    """
    return response_cache.get(_cache_key(user_message))


def set_cached_response(user_message, response):
    """Cache response, evicting the oldest entry when the cache is full.

    Args:
        user_message: The user's message text (normalized into the key).
        response: The response to store.
    """
    cache_key = _cache_key(user_message)
    # Overwriting an existing key does not grow the cache, so only evict when
    # a new key would push it past CACHE_SIZE.
    if (
        cache_key not in response_cache
        and response_cache
        and len(response_cache) >= CACHE_SIZE
    ):
        # dicts keep insertion order, so the first key is the oldest entry.
        del response_cache[next(iter(response_cache))]
    response_cache[cache_key] = response
def generate_response(user_message, max_tokens=512):
    """Generate a chat response for *user_message*, using the cache when possible.

    Args:
        user_message: The user's message text.
        max_tokens: Upper bound on newly generated tokens.

    Returns:
        str: The (possibly Kiswahili-enhanced) response. When the model cannot
        be loaded a fixed "still initializing" message is returned, and when
        generation raises, a short error message is returned instead; neither
        of those is cached.
    """
    # Check cache.
    # NOTE: the cache is keyed on the message only, so a response generated
    # with one max_tokens value is also served for other values.
    cached = get_cached_response(user_message)
    if cached:
        logger.info("📦 Using cached response")
        return cached
    # Ensure model is loaded
    if not model_loaded:
        success = load_model_optimized()
        if not success:
            return "I'm still initializing. Please try again in a moment."
    # Prepare messages
    messages = [
        {"role": "system", "content": STANLEY_AI_SYSTEM},
        {"role": "user", "content": user_message},
    ]
    try:
        # Apply chat template
        text = tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True,
        )
        # Tokenize
        inputs = tokenizer(text, return_tensors="pt").to(model.device)
        prompt_length = inputs['input_ids'].shape[1]
        # Generate with sampling. `early_stopping` is deliberately not passed:
        # it only applies to beam search and merely produces a warning when
        # combined with do_sample=True and a single beam.
        with torch.no_grad():
            outputs = model.generate(
                **inputs,
                max_new_tokens=max_tokens,
                temperature=0.7,
                top_p=0.9,
                top_k=40,
                do_sample=True,
                pad_token_id=tokenizer.eos_token_id,
                repetition_penalty=1.1,
                no_repeat_ngram_size=3,
            )
        # Decode only the newly generated tokens (skip the echoed prompt).
        response = tokenizer.decode(
            outputs[0][prompt_length:],
            skip_special_tokens=True,
        ).strip()
        # Enhance with Kiswahili
        enhanced_response = enhance_with_kiswahili(response, user_message)
        # Cache
        set_cached_response(user_message, enhanced_response)
        return enhanced_response
    except Exception as e:
        logger.error(f"Generation error: {e}")
        return f"Pole! I encountered an error: {str(e)[:100]}"
def _encode_png_data_uri(img, **save_options):
    """Serialize a PIL image as PNG and return it as a base64 ``data:`` URI."""
    buffered = io.BytesIO()
    img.save(buffered, format="PNG", **save_options)
    img_str = base64.b64encode(buffered.getvalue()).decode()
    return f"data:image/png;base64,{img_str}"


def generate_image_simple(prompt, width=512, height=512):
    """Simple image generation using PIL (no external dependencies).

    Draws a vertical colour gradient, adds a few shapes chosen by keywords in
    *prompt* (sun / tree / water) and a caption containing the prompt.

    Args:
        prompt: Text describing the desired image; also shown in the caption.
        width: Image width in pixels.
        height: Image height in pixels.

    Returns:
        str: A ``data:image/png;base64,...`` URI. If drawing fails, a solid
        randomly coloured image of the same size is returned instead.
    """
    try:
        # Create base image with gradient
        img = Image.new('RGB', (width, height), color='white')
        draw = ImageDraw.Draw(img)
        # One horizontal line per pixel row, colour interpolated top to bottom
        for row in range(height):
            r = int(100 + 155 * row / height)
            g = int(150 + 105 * row / height)
            b = int(200 + 55 * row / height)
            draw.line([(0, row), (width, row)], fill=(r, g, b))
        # Add shapes based on prompt keywords
        prompt_lower = prompt.lower()
        if any(word in prompt_lower for word in ['sun', 'bright', 'light']):
            draw.ellipse([width//3, height//3, 2*width//3, 2*height//3],
                         fill=(255, 255, 0), outline=(255, 200, 0))
        if any(word in prompt_lower for word in ['tree', 'nature']):
            # Trunk
            draw.rectangle([width//2-15, height//2, width//2+15, height-50],
                           fill=(101, 67, 33))
            # Foliage: five stacked ellipses
            for layer in range(5):
                y_offset = layer * 30
                draw.ellipse([width//2-60, height//2-100+y_offset,
                              width//2+60, height//2-40+y_offset],
                             fill=(34, 139, 34))
        if any(word in prompt_lower for word in ['water', 'ocean', 'river']):
            for x_start in range(0, width, 40):
                draw.arc([x_start, height-80, x_start+80, height], 0, 180,
                         fill=(64, 164, 223), width=3)
        # Try to add text; a failure here only costs the caption.
        try:
            font_size = min(width // 25, 20)
            try:
                font = ImageFont.truetype("arial.ttf", font_size)
            except Exception:
                # Was a bare `except:`, which also swallowed KeyboardInterrupt
                # and SystemExit. Any ordinary failure falls back to the
                # built-in bitmap font.
                font = ImageFont.load_default()
            # Truncate prompt for display
            display_text = prompt[:50] + "..." if len(prompt) > 50 else prompt
            text = f"STANLEY AI: {display_text}"
            # Calculate text position (horizontally centred, fixed top margin)
            bbox = draw.textbbox((0, 0), text, font=font)
            text_width = bbox[2] - bbox[0]
            text_height = bbox[3] - bbox[1]
            x = (width - text_width) // 2
            y = 20
            # Add text background
            draw.rectangle([x-10, y-5, x+text_width+10, y+text_height+5],
                           fill=(0, 0, 0, 180))
            # Add text
            draw.text((x, y), text, fill=(255, 255, 255), font=font)
        except Exception as font_error:
            logger.warning(f"Could not add text: {font_error}")
        # Convert to base64
        return _encode_png_data_uri(img, optimize=True)
    except Exception as e:
        logger.error(f"Image generation error: {e}")
        # Ultimate fallback - solid color
        img = Image.new('RGB', (width, height),
                        color=(random.randint(50, 200),
                               random.randint(50, 200),
                               random.randint(50, 200)))
        return _encode_png_data_uri(img)
# ============================================================================
# FLASK ROUTES
# ============================================================================
@app.route('/')
def home():
    """Return a JSON summary of the API: version, active model, cache size and endpoints."""
    advertised_endpoints = [
        "/api/chat - Main chat endpoint",
        "/api/chat-fast - Faster responses",
        "/api/generate-image - Simple image generation",
        "/api/health - System health check",
        "/api/cache/clear - Clear response cache",
    ]
    summary = {
        "message": "🚀 STANLEY AI API is running!",
        "version": "3.0",
        "status": "active",
        # Falls back to a placeholder while no model name has been set yet.
        "model": current_model_name or "Loading...",
        "optimized": "true",
        "cache_size": len(response_cache),
        "endpoints": advertised_endpoints,
    }
    return jsonify(summary)
@app.route('/api/health')
def health_check():
    """Report model-loading state, current model name, cache size and a timestamp."""
    if model_loaded:
        status_label = "healthy"
    else:
        status_label = "loading"
    report = {
        "status": status_label,
        "model_loaded": model_loaded,
        "model": current_model_name,
        "cache_entries": len(response_cache),
        "timestamp": time.time(),
    }
    return jsonify(report)
@app.route('/api/chat', methods=['POST'])
def chat():
    """Main chat endpoint.

    Expects a JSON object with a non-empty string ``message``. Responds with
    the generated text plus timing and language metadata.

    Returns:
        200 with the response payload; 400 when the body is not a JSON object
        or ``message`` is missing, empty or not a string; 500 on unexpected
        errors.
    """
    try:
        start_time = time.time()
        # silent=True: a missing, malformed or non-JSON body yields None
        # instead of raising, so it can be reported as a client error (400)
        # rather than being caught below and turned into a 500.
        data = request.get_json(silent=True)
        if not isinstance(data, dict):
            return jsonify({"error": "Tafadhali provide a message"}), 400
        user_message = data.get('message', '')
        if not isinstance(user_message, str) or not user_message.strip():
            return jsonify({"error": "Tafadhali provide a message"}), 400
        logger.info(f"💬 Processing: {user_message[:60]}...")
        # Generate response
        response = generate_response(user_message)
        response_time = round(time.time() - start_time, 2)
        # Check if response contains Kiswahili
        has_kiswahili = detect_kiswahili_context(response)
        return jsonify({
            "response": response,
            "status": "success",
            "response_time": f"{response_time}s",
            "model": current_model_name,
            "cultural_context": has_kiswahili,
            "language": "en+sw" if has_kiswahili else "en",
            "word_count": len(response.split()),
        })
    except Exception as e:
        logger.error(f"Chat error: {e}")
        return jsonify({
            "error": f"Pole! Error: {str(e)[:100]}",
            "status": "error",
        }), 500
@app.route('/api/chat-fast', methods=['POST'])
def chat_fast():
    """Faster endpoint with shorter responses.

    Same request format as the main chat endpoint (JSON object with a string
    ``message``), but generation is capped at 256 new tokens.

    Returns:
        200 with the response payload; 400 when the body is not a JSON object
        or ``message`` is missing, empty or not a string; 500 on unexpected
        errors.
    """
    try:
        # silent=True turns an unparsable body into None so it is answered
        # with 400 instead of falling into the generic 500 handler.
        data = request.get_json(silent=True)
        if not isinstance(data, dict):
            return jsonify({"error": "Please provide a message"}), 400
        user_message = data.get('message', '')
        if not isinstance(user_message, str) or not user_message.strip():
            return jsonify({"error": "Please provide a message"}), 400
        # Quick response with fewer tokens
        response = generate_response(user_message, max_tokens=256)
        return jsonify({
            "response": response,
            "status": "success",
            "model": f"{current_model_name} (fast mode)",
            "response_type": "concise",
        })
    except Exception as e:
        # Keep the client message generic but record the cause server-side.
        logger.error(f"Chat-fast error: {e}")
        return jsonify({"error": "Quick response failed"}), 500
def _clamp_image_dimension(value, default=512, lower=1, upper=1024):
    """Coerce a requested pixel dimension to an int within [lower, upper].

    Values that cannot be read as a number (None, text, booleans, infinities)
    fall back to *default*.
    """
    if isinstance(value, bool):
        return default
    try:
        size = int(value)
    except (TypeError, ValueError, OverflowError):
        return default
    return max(lower, min(size, upper))


@app.route('/api/generate-image', methods=['POST'])
def generate_image_endpoint():
    """Simple image generation endpoint.

    Accepts an optional JSON object with ``prompt`` (text), ``width`` and
    ``height`` (pixels, clamped to 1..1024, default 512).

    Returns:
        200 with the image as a data URI plus metadata; 500 when no image
        could be produced or an unexpected error occurred.
    """
    try:
        # A missing or unparsable body falls back to all defaults.
        data = request.get_json(silent=True)
        if not isinstance(data, dict):
            data = {}
        prompt = data.get('prompt', 'A beautiful landscape')
        if not isinstance(prompt, str):
            # The drawing code slices and lower-cases the prompt.
            prompt = str(prompt)
        # Clamp instead of a bare min(): the old form crashed on string/null
        # values and let zero, negative or float sizes through to PIL.
        width = _clamp_image_dimension(data.get('width', 512))
        height = _clamp_image_dimension(data.get('height', 512))
        logger.info(f"🎨 Generating image: {prompt[:40]}...")
        image_data = generate_image_simple(prompt, width, height)
        if image_data:
            return jsonify({
                "image": image_data,
                "prompt": prompt,
                "status": "success",
                "method": "PIL generated",
                "dimensions": f"{width}x{height}",
            })
        return jsonify({"error": "Could not generate image"}), 500
    except Exception as e:
        logger.error(f"Image endpoint error: {e}")
        return jsonify({"error": f"Image error: {str(e)[:80]}"}), 500
@app.route('/api/cache/clear', methods=['POST'])
def clear_cache():
    """Empty the response cache and report how many entries were dropped."""
    # Count before clearing so the reply can say how many entries were removed.
    entries_removed = len(response_cache)
    response_cache.clear()
    # Also release cached GPU memory when CUDA is present.
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    result = {
        "status": "success",
        "cleared_entries": entries_removed,
        "message": "Cache cleared",
    }
    return jsonify(result)
@app.route('/api/switch-model', methods=['POST'])
def switch_model():
    """Switch between available models.

    Accepts an optional JSON object with ``model`` set to one of the
    ``MODEL_CONFIG`` keys. A missing or unknown value selects "primary".

    Returns:
        200 with the active model name when the requested model is loaded
        (including when it was already active); 500 when loading failed.
    """
    try:
        # A missing or unparsable body is treated as "no preference".
        data = request.get_json(silent=True)
        if not isinstance(data, dict):
            data = {}
        model_choice = data.get('model', 'primary')
        if not isinstance(model_choice, str):
            # Lists/dicts are unhashable and would crash the lookup below.
            model_choice = 'primary'
        model_name = MODEL_CONFIG.get(model_choice, MODEL_CONFIG["primary"])
        success = load_model_optimized(model_name)
        # Requesting the model that is already active is a success, whatever
        # the loader returned for its "nothing to do" path.
        already_active = model_loaded and current_model_name == model_name
        if success or already_active:
            return jsonify({
                "status": "success",
                "message": f"Switched to {model_name}",
                "current_model": current_model_name,
            })
        return jsonify({"error": "Failed to switch model"}), 500
    except Exception as e:
        logger.error(f"Switch-model error: {e}")
        return jsonify({"error": str(e)}), 500
# ============================================================================
# INITIALIZATION & STARTUP
# ============================================================================
def initialize_app():
    """Start loading the default model on a daemon thread and return immediately.

    The function does not wait for the load to finish; the model may still be
    loading after it returns.
    """
    logger.info("🚀 Initializing STANLEY AI...")
    # Load model in background thread
    loader_thread = Thread(target=load_model_optimized, daemon=True)
    loader_thread.start()
    logger.info("✅ STANLEY AI initialized and ready!")
# Initialize on import
# Runs whenever the module is imported, not only when it is executed as a
# script.
initialize_app()
if __name__ == '__main__':
    # PORT comes from the environment; 7860 is used when it is not set.
    app.run(
        host='0.0.0.0',
        port=int(os.environ.get('PORT', 7860)),
        debug=False,
        threaded=True,
    )