|
|
""" |
|
|
STANLEY AI - Optimized Flask Backend |
|
|
Deploy on Hugging Face Spaces with fast, smaller models |
|
|
""" |
|
|
|
|
|
from flask import Flask, request, jsonify, send_file |
|
|
from flask_cors import CORS |
|
|
from transformers import AutoModelForCausalLM, AutoTokenizer, TextStreamer |
|
|
import torch |
|
|
import time |
|
|
import re |
|
|
import logging |
|
|
from threading import Thread |
|
|
import queue |
|
|
import io |
|
|
import base64 |
|
|
import random |
|
|
from PIL import Image, ImageDraw, ImageFont |
|
|
import os |
|
|
import gc |
|
|
|
|
|
|
|
|
# Module-wide logging: INFO level so model-load progress shows in Space logs.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Flask app with permissive CORS so a browser frontend on another origin
# can call the API endpoints below.
app = Flask(__name__)
CORS(app)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Candidate models, tried in order: "primary" first, then "fallback" if it
# fails to load. "tiny" is only reachable via the /api/switch-model endpoint.
MODEL_CONFIG = {
    # BUG FIX: "Qwen/Qwen2.5-1.8B-Instruct" does not exist (1.8B is a
    # Qwen1.5 size); the closest Qwen2.5 checkpoint is 1.5B-Instruct.
    "primary": "Qwen/Qwen2.5-1.5B-Instruct",
    "fallback": "microsoft/Phi-3-mini-4k-instruct",
    "tiny": "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
}
|
|
|
|
|
# Lazily populated model state; filled in by load_model_optimized(), which
# runs in a background thread at startup (see initialize_app()).
model = None
tokenizer = None
model_loaded = False
current_model_name = None

# In-memory response cache: normalized message prefix -> generated reply.
response_cache = {}
# Maximum cached entries before random eviction (see set_cached_response).
CACHE_SIZE = 200
|
|
|
|
|
|
|
|
# System prompt injected as the "system" message for every generation
# (see generate_response()).
STANLEY_AI_SYSTEM = """You are STANLEY AI - an advanced assistant with Kiswahili cultural knowledge.
Provide helpful, concise responses. Integrate Kiswahili phrases naturally when relevant.

Key capabilities:
- Answer questions knowledgeably
- Use Kiswahili for greetings, proverbs, and cultural references
- Explain concepts clearly
- Be efficient and to the point

Format: Use **bold** for emphasis. Keep responses under 300 words unless detailed explanation is needed."""
|
|
|
|
|
|
|
|
# Static Kiswahili reference data consumed by enhance_with_kiswahili():
# greetings (English key -> Kiswahili phrase), proverbs with inline English
# translations, and Lion King name/phrase glosses.
KISWAHILI_KNOWLEDGE = {
    "greetings": {
        "hello": "Jambo / Habari",
        "how_are_you": "Habari yako?",
        "goodbye": "Kwaheri / Tuonane tena",
        "thank_you": "Asante sana",
        "welcome": "Karibu / Karibuni"
    },
    "proverbs": [
        "Mwenye pupa hadiriki kula tamu - The impatient one misses sweet things.",
        "Asiyefunzwa na mamae hufunzwa na ulimwengu - He who is not taught by his mother is taught by the world.",
        "Haraka haraka haina baraka - Hurry hurry has no blessing.",
        "Ukitaka kwenda haraka, nenda peke yako. Ukitaka kwenda mbali, nenda na wenzako - If you want to go fast, go alone. If you want to go far, go together."
    ],
    "lion_king": {
        "simba": "Lion (the main character)",
        "rafiki": "Friend (the wise baboon)",
        "hakuna_matata": "No worries / No problems",
        "mufasa": "Simba's father, the king",
        "nala": "Simba's childhood friend and queen"
    }
}
|
|
|
|
|
def load_model_optimized(model_name=None):
    """Load (or swap in) a causal LM with memory optimizations for HF Spaces.

    Args:
        model_name: Hugging Face repo id to load; defaults to the primary
            model in MODEL_CONFIG.

    Returns:
        True when a model is ready (including the already-loaded fast path),
        False when every candidate model failed to load.
    """
    global model, tokenizer, model_loaded, current_model_name

    # Fast path: the requested model is already resident.
    # BUG FIX: previously returned None here, which callers (e.g.
    # /api/switch-model) treated as failure.
    if model_loaded and model_name == current_model_name:
        return True

    if not model_name:
        model_name = MODEL_CONFIG["primary"]

    logger.info(f"🚀 Loading model: {model_name}")

    try:
        # Release the previous model before allocating a new one.
        if model is not None:
            del model
            del tokenizer
            gc.collect()
            if torch.cuda.is_available():
                torch.cuda.empty_cache()

        tokenizer = AutoTokenizer.from_pretrained(
            model_name,
            trust_remote_code=True,
            use_fast=True
        )

        # Some chat models ship without a pad token; reuse EOS for padding.
        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token

        model_kwargs = {
            "device_map": "auto",
            "low_cpu_mem_usage": True,
            "trust_remote_code": True,
        }
        if torch.cuda.is_available():
            model_kwargs["torch_dtype"] = torch.float16
            # BUG FIX: passing load_in_4bit=True directly to from_pretrained
            # is deprecated/removed in recent transformers and requires
            # CUDA + bitsandbytes; use a BitsAndBytesConfig when available.
            try:
                from transformers import BitsAndBytesConfig
                model_kwargs["quantization_config"] = BitsAndBytesConfig(load_in_4bit=True)
            except ImportError:
                logger.warning("bitsandbytes config unavailable; loading without 4-bit quantization")
        else:
            # float16 is slow/unsupported for many ops on CPU-only Spaces.
            model_kwargs["torch_dtype"] = torch.float32

        model = AutoModelForCausalLM.from_pretrained(model_name, **model_kwargs)

        model.eval()
        model_loaded = True
        current_model_name = model_name

        # Run one tiny generation so the first real request isn't slow.
        prewarm_model()

        logger.info(f"✅ Model loaded successfully: {model_name}")
        logger.info(f"📊 Model device: {model.device}")

        return True

    except Exception as e:
        logger.error(f"❌ Error loading model: {e}")

        # Cascade to the fallback model once before giving up entirely.
        if model_name != MODEL_CONFIG["fallback"]:
            logger.info("🔄 Trying fallback model...")
            return load_model_optimized(MODEL_CONFIG["fallback"])
        else:
            logger.error("❌ All models failed to load")
            model_loaded = False
            return False
|
|
|
|
|
def prewarm_model():
    """Run one tiny greedy generation so later requests avoid cold-start cost."""
    try:
        warmup_messages = [
            {"role": "system", "content": "Say hello briefly."},
            {"role": "user", "content": "Hello, STANLEY AI!"},
        ]

        prompt = tokenizer.apply_chat_template(
            warmup_messages, tokenize=False, add_generation_prompt=True
        )
        encoded = tokenizer(prompt, return_tensors="pt").to(model.device)

        # Greedy, 10-token generation: enough to trigger kernel/cache warmup.
        with torch.no_grad():
            model.generate(**encoded, max_new_tokens=10, do_sample=False)

        logger.info("✅ Model pre-warmed successfully!")
    except Exception as e:
        logger.warning(f"Pre-warm failed: {e}")
|
|
|
|
|
def detect_kiswahili_context(text):
    """Return True when *text* mentions Kiswahili words or cultural references."""
    if not text:
        return False

    # Case-insensitive substring match against a fixed keyword list.
    keywords = (
        'swahili', 'kiswahili', 'hakuna', 'matata', 'asante', 'rafiki',
        'jambo', 'mambo', 'pole', 'sawa', 'karibu', 'kwaheri', 'simba',
        'lion king', 'mufasa', 'nala', 'kenya', 'tanzania', 'africa',
        'habari', 'nze', 'pumbaa', 'timon', 'safari', 'ujamaa',
    )
    lowered = text.lower()
    for keyword in keywords:
        if keyword in lowered:
            return True
    return False
|
|
|
|
|
def enhance_with_kiswahili(response, user_message):
    """Decorate *response* with Kiswahili flavor when the user's message
    carries a Kiswahili/cultural context; otherwise return it unchanged.

    Adds a random greeting prefix, optionally a proverb (for wisdom-seeking
    messages) and a Lion King fact (for lion-related messages).
    """
    if not detect_kiswahili_context(user_message):
        return response

    greeting = random.choice(list(KISWAHILI_KNOWLEDGE["greetings"].values()))
    lowered = user_message.lower()

    # Wisdom-seeking messages also get a proverb appended.
    if any(word in lowered for word in ['advice', 'wisdom', 'lesson', 'teach']):
        proverb = random.choice(KISWAHILI_KNOWLEDGE["proverbs"])
        enhanced = f"{greeting}! {response}\n\n**🔥 Kiswahili Proverb:** {proverb}"
    else:
        enhanced = f"{greeting}! {response}"

    # Lion King references earn a translation trivia footer.
    if any(word in lowered for word in ['lion', 'simba', 'mufasa', 'disney']):
        lion_fact = "Did you know? 'Simba' means lion in Kiswahili, and 'Rafiki' means friend!"
        enhanced += f"\n\n{lion_fact}"

    return enhanced
|
|
|
|
|
def get_cached_response(user_message):
    """Return the cached reply for *user_message*, or None on a cache miss."""
    # Key = lowercased, stripped 80-char prefix (must match set_cached_response).
    key = user_message.lower().strip()[:80]
    return response_cache.get(key)
|
|
|
|
|
def set_cached_response(user_message, response):
    """Store *response* under a normalized key, evicting one random entry
    when the cache is at capacity."""
    key = user_message.lower().strip()[:80]

    # Random eviction keeps bookkeeping trivial (no LRU tracking needed).
    if len(response_cache) >= CACHE_SIZE:
        evicted = random.choice(list(response_cache.keys()))
        response_cache.pop(evicted)

    response_cache[key] = response
|
|
|
|
|
def generate_response(user_message, max_tokens=512):
    """Generate a chat response for *user_message*, using the cache when possible.

    Args:
        user_message: Raw user text.
        max_tokens: Upper bound on newly generated tokens.

    Returns:
        The (possibly Kiswahili-enhanced) response string, or a friendly
        error message if the model is unavailable or generation fails.
    """
    # Cheap cache hit avoids a full generation pass.
    cached = get_cached_response(user_message)
    if cached:
        logger.info("📦 Using cached response")
        return cached

    # Model loads in a background thread at startup; trigger a synchronous
    # load here only if it hasn't finished (or failed) yet.
    if not model_loaded:
        success = load_model_optimized()
        if not success:
            return "I'm still initializing. Please try again in a moment."

    messages = [
        {"role": "system", "content": STANLEY_AI_SYSTEM},
        {"role": "user", "content": user_message}
    ]

    try:
        text = tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True
        )

        inputs = tokenizer(text, return_tensors="pt").to(model.device)

        # BUG FIX: removed early_stopping=True — it only applies to beam
        # search and triggers a transformers warning under sampling.
        with torch.no_grad():
            outputs = model.generate(
                **inputs,
                max_new_tokens=max_tokens,
                temperature=0.7,
                top_p=0.9,
                top_k=40,
                do_sample=True,
                pad_token_id=tokenizer.eos_token_id,
                repetition_penalty=1.1,
                no_repeat_ngram_size=3
            )

        # Decode only the newly generated tokens (skip the echoed prompt).
        response = tokenizer.decode(
            outputs[0][inputs['input_ids'].shape[1]:],
            skip_special_tokens=True
        ).strip()

        enhanced_response = enhance_with_kiswahili(response, user_message)

        set_cached_response(user_message, enhanced_response)

        return enhanced_response

    except Exception as e:
        logger.error(f"Generation error: {e}")
        return f"Pole! I encountered an error: {str(e)[:100]}"
|
|
|
|
|
def generate_image_simple(prompt, width=512, height=512):
    """Render a simple placeholder image for *prompt* using only PIL.

    Draws a vertical gradient background, adds keyword-driven decorations
    (sun, tree, water), overlays the prompt as a caption, and returns the
    PNG as a base64 data URI. On any failure it falls back to a solid
    random-color image so the endpoint still returns data.

    Args:
        prompt: Text used for keyword detection and the caption.
        width: Output width in pixels.
        height: Output height in pixels.

    Returns:
        A "data:image/png;base64,..." string.
    """
    try:
        img = Image.new('RGB', (width, height), color='white')
        draw = ImageDraw.Draw(img)

        # Vertical gradient background (light blue-ish, top to bottom).
        for i in range(height):
            r = int(100 + 155 * i / height)
            g = int(150 + 105 * i / height)
            b = int(200 + 55 * i / height)
            draw.line([(0, i), (width, i)], fill=(r, g, b))

        prompt_lower = prompt.lower()

        # Keyword-driven decorations.
        if any(word in prompt_lower for word in ['sun', 'bright', 'light']):
            draw.ellipse([width//3, height//3, 2*width//3, 2*height//3],
                         fill=(255, 255, 0), outline=(255, 200, 0))

        if any(word in prompt_lower for word in ['tree', 'nature']):
            # Trunk plus a stack of green ellipses as the canopy.
            draw.rectangle([width//2-15, height//2, width//2+15, height-50],
                           fill=(101, 67, 33))
            for i in range(5):
                y_offset = i * 30
                draw.ellipse([width//2-60, height//2-100+y_offset,
                              width//2+60, height//2-40+y_offset],
                             fill=(34, 139, 34))

        if any(word in prompt_lower for word in ['water', 'ocean', 'river']):
            for i in range(0, width, 40):
                draw.arc([i, height-80, i+80, height], 0, 180,
                         fill=(64, 164, 223), width=3)

        # Caption overlay — best-effort; skipped if text rendering fails.
        try:
            font_size = min(width // 25, 20)
            try:
                # BUG FIX: was a bare `except:`; ImageFont.truetype raises
                # OSError when the font file cannot be found or read.
                font = ImageFont.truetype("arial.ttf", font_size)
            except OSError:
                font = ImageFont.load_default()

            display_text = prompt[:50] + "..." if len(prompt) > 50 else prompt
            text = f"STANLEY AI: {display_text}"

            bbox = draw.textbbox((0, 0), text, font=font)
            text_width = bbox[2] - bbox[0]
            text_height = bbox[3] - bbox[1]

            x = (width - text_width) // 2
            y = 20

            # BUG FIX: the image is mode 'RGB', so the previous RGBA fill
            # (0, 0, 0, 180) had no transparency effect; use plain black.
            draw.rectangle([x-10, y-5, x+text_width+10, y+text_height+5],
                           fill=(0, 0, 0))

            draw.text((x, y), text, fill=(255, 255, 255), font=font)

        except Exception as font_error:
            logger.warning(f"Could not add text: {font_error}")

        buffered = io.BytesIO()
        img.save(buffered, format="PNG", optimize=True)
        img_str = base64.b64encode(buffered.getvalue()).decode()

        return f"data:image/png;base64,{img_str}"

    except Exception as e:
        logger.error(f"Image generation error: {e}")
        # Fallback: solid random-color image.
        img = Image.new('RGB', (width, height),
                        color=(random.randint(50, 200),
                               random.randint(50, 200),
                               random.randint(50, 200)))
        buffered = io.BytesIO()
        img.save(buffered, format="PNG")
        img_str = base64.b64encode(buffered.getvalue()).decode()
        return f"data:image/png;base64,{img_str}"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@app.route('/')
def home():
    """Root endpoint: API banner with status, current model, and endpoint list."""
    return jsonify({
        "message": "🚀 STANLEY AI API is running!",
        "version": "3.0",
        "status": "active",
        "model": current_model_name or "Loading...",
        # BUG FIX: was the string "true"; a JSON boolean is the correct type.
        "optimized": True,
        "cache_size": len(response_cache),
        "endpoints": [
            "/api/chat - Main chat endpoint",
            "/api/chat-fast - Faster responses",
            "/api/generate-image - Simple image generation",
            "/api/health - System health check",
            "/api/cache/clear - Clear response cache"
        ]
    })
|
|
|
|
|
@app.route('/api/health')
def health_check():
    """Lightweight liveness/readiness probe for the Space."""
    payload = {
        "status": "healthy" if model_loaded else "loading",
        "model_loaded": model_loaded,
        "model": current_model_name,
        "cache_entries": len(response_cache),
        "timestamp": time.time(),
    }
    return jsonify(payload)
|
|
|
|
|
@app.route('/api/chat', methods=['POST'])
def chat():
    """Main chat endpoint: JSON {"message": str} -> generated reply payload."""
    try:
        start_time = time.time()
        # BUG FIX: get_json() returns None / raises on a malformed or missing
        # JSON body, which previously surfaced as a 500; silent=True lets us
        # return a clean 400 instead.
        data = request.get_json(silent=True) or {}
        user_message = data.get('message', '')

        if not user_message:
            return jsonify({"error": "Tafadhali provide a message"}), 400

        logger.info(f"💬 Processing: {user_message[:60]}...")

        response = generate_response(user_message)
        response_time = round(time.time() - start_time, 2)

        # NOTE: context detection runs on the generated response text.
        has_kiswahili = detect_kiswahili_context(response)

        return jsonify({
            "response": response,
            "status": "success",
            "response_time": f"{response_time}s",
            "model": current_model_name,
            "cultural_context": has_kiswahili,
            "language": "en+sw" if has_kiswahili else "en",
            "word_count": len(response.split())
        })

    except Exception as e:
        logger.error(f"Chat error: {e}")
        return jsonify({
            "error": f"Pole! Error: {str(e)[:100]}",
            "status": "error"
        }), 500
|
|
|
|
|
@app.route('/api/chat-fast', methods=['POST'])
def chat_fast():
    """Faster endpoint: same pipeline as /api/chat, capped at 256 new tokens."""
    try:
        # BUG FIX: silent=True turns a malformed/missing JSON body into an
        # empty dict (-> 400) instead of an uncaught raise (-> 500).
        data = request.get_json(silent=True) or {}
        user_message = data.get('message', '')

        if not user_message:
            return jsonify({"error": "Please provide a message"}), 400

        response = generate_response(user_message, max_tokens=256)

        return jsonify({
            "response": response,
            "status": "success",
            "model": f"{current_model_name} (fast mode)",
            "response_type": "concise"
        })

    except Exception as e:
        # BUG FIX: the original swallowed the exception without logging it.
        logger.error(f"Fast chat error: {e}")
        return jsonify({"error": "Quick response failed"}), 500
|
|
|
|
|
@app.route('/api/generate-image', methods=['POST'])
def generate_image_endpoint():
    """Image endpoint: JSON {"prompt", "width", "height"} -> base64 data URI."""
    try:
        # silent=True: malformed/missing JSON becomes defaults, not a 500.
        data = request.get_json(silent=True) or {}
        prompt = data.get('prompt', 'A beautiful landscape')
        # BUG FIX: the original only capped the maximum, so zero/negative
        # dimensions passed straight through to PIL; clamp to [64, 1024].
        width = max(64, min(int(data.get('width', 512)), 1024))
        height = max(64, min(int(data.get('height', 512)), 1024))

        logger.info(f"🎨 Generating image: {prompt[:40]}...")

        image_data = generate_image_simple(prompt, width, height)

        if image_data:
            return jsonify({
                "image": image_data,
                "prompt": prompt,
                "status": "success",
                "method": "PIL generated",
                "dimensions": f"{width}x{height}"
            })
        else:
            return jsonify({"error": "Could not generate image"}), 500

    except Exception as e:
        # Log before responding so failures are diagnosable from Space logs.
        logger.error(f"Image endpoint error: {e}")
        return jsonify({"error": f"Image error: {str(e)[:80]}"}), 500
|
|
|
|
|
@app.route('/api/cache/clear', methods=['POST'])
def clear_cache():
    """Drop every cached response and release any spare CUDA memory."""
    entries_removed = len(response_cache)
    response_cache.clear()

    # Also return unused GPU memory to the driver when a GPU is present.
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    return jsonify({
        "status": "success",
        "cleared_entries": entries_removed,
        "message": "Cache cleared"
    })
|
|
|
|
|
@app.route('/api/switch-model', methods=['POST'])
def switch_model():
    """Switch between the models declared in MODEL_CONFIG.

    Body: JSON {"model": "primary" | "fallback" | "tiny"}. Unknown choices
    fall back to the primary model (preserving the original behavior).
    """
    try:
        # BUG FIX: silent=True keeps a missing/malformed body from raising
        # (which previously produced a 500 via the except branch).
        data = request.get_json(silent=True) or {}
        model_choice = data.get('model', 'primary')

        model_name = MODEL_CONFIG.get(model_choice, MODEL_CONFIG["primary"])

        success = load_model_optimized(model_name)

        if success:
            return jsonify({
                "status": "success",
                "message": f"Switched to {model_name}",
                "current_model": current_model_name
            })
        else:
            return jsonify({"error": "Failed to switch model"}), 500

    except Exception as e:
        return jsonify({"error": str(e)}), 500
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def initialize_app():
    """Start the model load in a daemon thread so the server boots instantly."""
    logger.info("🚀 Initializing STANLEY AI...")

    # load_model_optimized() takes no required args, so it can be the thread
    # target directly; daemon=True lets the process exit without joining.
    loader = Thread(target=load_model_optimized, daemon=True)
    loader.start()

    logger.info("✅ STANLEY AI initialized and ready!")
|
|
|
|
|
|
|
|
# Kick off the background model load at import time, so it also runs when
# the module is served by a WSGI server rather than executed directly.
initialize_app()

if __name__ == '__main__':
    # PORT comes from the environment when provided; 7860 is the default.
    port = int(os.environ.get('PORT', 7860))
    # Bind on all interfaces; threaded=True serves concurrent requests.
    app.run(debug=False, host='0.0.0.0', port=port, threaded=True)