#!/usr/bin/env python3
"""
NEED AI - Production Flask API (FIXED VERSION)
Models are checked at runtime, not at import time
"""
from flask import Flask, request, jsonify
from flask_cors import CORS
import torch
import logging
import os
import time
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
app = Flask(__name__)
CORS(app)
# Set username as environment variable
os.environ['HF_USERNAME'] = 'yogami9'
HF_USERNAME = 'yogami9'
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
logger.info(f"πŸ–₯️ Device: {DEVICE}")
logger.info(f"πŸ‘€ HF Username: {HF_USERNAME}")
class ModelCache:
    """Lazy-loading model cache."""

    def __init__(self):
        self.models = {}
        self.tokenizers = {}
        self.load_attempts = {}
        logger.info("📦 Model cache initialized")

    def _load_model(self, model_key, model_id, loader_func):
        """Generic model loader with error handling."""
        if model_key in self.models:
            return self.models[model_key], self.tokenizers.get(model_key)
        # Track load attempts to avoid infinite retries
        if self.load_attempts.get(model_key, 0) > 3:
            raise Exception(f"Max load attempts exceeded for {model_key}")
        self.load_attempts[model_key] = self.load_attempts.get(model_key, 0) + 1
        try:
            logger.info(f"📥 Loading {model_key} from {model_id}...")
            result = loader_func(model_id)
            logger.info(f"✅ {model_key} loaded successfully")
            return result
        except Exception as e:
            logger.error(f"❌ Failed to load {model_key}: {str(e)[:200]}")
            raise

    def get_category_model(self):
        def loader(model_id):
            from transformers import T5ForConditionalGeneration, T5Tokenizer
            model = T5ForConditionalGeneration.from_pretrained(model_id).to(DEVICE)
            tokenizer = T5Tokenizer.from_pretrained(model_id)
            self.models['category'] = model
            self.tokenizers['category'] = tokenizer
            return model, tokenizer
        return self._load_model('category', f'{HF_USERNAME}/need-category-recommendation', loader)

    def get_chat_model(self):
        def loader(model_id):
            from transformers import T5ForConditionalGeneration, T5Tokenizer
            model = T5ForConditionalGeneration.from_pretrained(model_id).to(DEVICE)
            tokenizer = T5Tokenizer.from_pretrained(model_id)
            self.models['chat'] = model
            self.tokenizers['chat'] = tokenizer
            return model, tokenizer
        return self._load_model('chat', f'{HF_USERNAME}/need-chat-support', loader)

    def get_service_model(self):
        def loader(model_id):
            from transformers import T5ForConditionalGeneration, T5Tokenizer
            model = T5ForConditionalGeneration.from_pretrained(model_id).to(DEVICE)
            tokenizer = T5Tokenizer.from_pretrained(model_id)
            self.models['service'] = model
            self.tokenizers['service'] = tokenizer
            return model, tokenizer
        return self._load_model('service', f'{HF_USERNAME}/need-service-description', loader)

    def get_search_model(self):
        def loader(model_id):
            from sentence_transformers import SentenceTransformer
            model = SentenceTransformer(model_id)
            self.models['search'] = model
            return model, None
        return self._load_model('search', f'{HF_USERNAME}/need-semantic-search', loader)

    def get_moderation_model(self):
        def loader(model_id):
            from transformers import AutoModelForSequenceClassification, AutoTokenizer
            model = AutoModelForSequenceClassification.from_pretrained(model_id).to(DEVICE)
            tokenizer = AutoTokenizer.from_pretrained(model_id)
            self.models['moderation'] = model
            self.tokenizers['moderation'] = tokenizer
            return model, tokenizer
        return self._load_model('moderation', f'{HF_USERNAME}/need-content-moderation', loader)

# Initialize cache
model_cache = ModelCache()
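
# Usage sketch (illustrative, not executed at import time): the first getter
# call downloads and caches a model; subsequent calls return the cached instance.
#   model, tokenizer = model_cache.get_category_model()  # slow first time (download)
#   model, tokenizer = model_cache.get_category_model()  # fast (cache hit)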

@app.route('/', methods=['GET'])
def home():
    return jsonify({
        'name': 'NEED AI API',
        'version': '2.0.1',
        'status': 'running',
        'username': HF_USERNAME,
        'models_loaded': len(model_cache.models),
        'endpoints': {
            'health': '/health',
            'category': '/api/category',
            'chat': '/api/chat',
            'service': '/api/service',
            'search': '/api/search',
            'moderate': '/api/moderate'
        },
        'documentation': 'https://github.com/Need-Service-App/need-ai-model',
        'note': 'First request per model takes 30-60 seconds (model download)'
    })

@app.route('/health', methods=['GET'])
def health():
    return jsonify({
        'status': 'healthy',
        'device': str(DEVICE),
        'gpu_available': torch.cuda.is_available(),
        'models_loaded': len(model_cache.models),
        'models_cached': list(model_cache.models.keys()),
        'username': HF_USERNAME,
        'note': 'Models load on first use'
    })
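
# Quick smoke test (illustrative; assumes the server listens on localhost:7860):
#   curl http://localhost:7860/health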

@app.route('/api/category', methods=['POST'])
def predict_category():
    try:
        start = time.time()
        data = request.get_json()
        if not data or 'query' not in data:
            return jsonify({'error': 'Missing "query" in request body'}), 400
        query = data['query']
        logger.info(f"📥 Category request: {query[:50]}")
        # Load model on demand
        model, tokenizer = model_cache.get_category_model()
        input_text = f"categorize: {query}"
        input_ids = tokenizer.encode(input_text, return_tensors="pt").to(DEVICE)
        with torch.no_grad():
            outputs = model.generate(input_ids, max_length=32, num_beams=4)
        category = tokenizer.decode(outputs[0], skip_special_tokens=True)
        elapsed = time.time() - start
        logger.info(f"✅ Category: {category} ({elapsed:.2f}s)")
        return jsonify({
            'query': query,
            'category': category,
            'inference_time': f'{elapsed:.3f}s',
            'model': 'category-recommendation'
        })
    except Exception as e:
        logger.error(f"❌ Category error: {str(e)}")
        return jsonify({
            'error': 'Inference failed',
            'message': str(e)[:300],
            'suggestion': 'Check logs. Model may be loading (wait 30s) or not accessible.'
        }), 500
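
# Example request (illustrative; host and query value are hypothetical):
#   curl -X POST http://localhost:7860/api/category \
#        -H "Content-Type: application/json" \
#        -d '{"query": "I need someone to fix my leaking kitchen sink"}'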

@app.route('/api/chat', methods=['POST'])
def answer_question():
    try:
        start = time.time()
        data = request.get_json()
        if not data or 'question' not in data:
            return jsonify({'error': 'Missing "question" in request body'}), 400
        question = data['question']
        logger.info(f"📥 Chat request: {question[:50]}")
        model, tokenizer = model_cache.get_chat_model()
        input_text = f"answer question: {question}"
        input_ids = tokenizer.encode(input_text, return_tensors="pt").to(DEVICE)
        with torch.no_grad():
            outputs = model.generate(input_ids, max_length=256, num_beams=4)
        answer = tokenizer.decode(outputs[0], skip_special_tokens=True)
        elapsed = time.time() - start
        logger.info(f"✅ Chat answer ({elapsed:.2f}s)")
        return jsonify({
            'question': question,
            'answer': answer,
            'inference_time': f'{elapsed:.3f}s',
            'model': 'chat-support'
        })
    except Exception as e:
        logger.error(f"❌ Chat error: {str(e)}")
        return jsonify({
            'error': 'Inference failed',
            'message': str(e)[:300]
        }), 500
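
# Example request (illustrative; host and question value are hypothetical):
#   curl -X POST http://localhost:7860/api/chat \
#        -H "Content-Type: application/json" \
#        -d '{"question": "How do I book a service?"}'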

@app.route('/api/service', methods=['POST'])
def generate_description():
    try:
        start = time.time()
        data = request.get_json()
        if not data or 'service_info' not in data:
            return jsonify({'error': 'Missing "service_info" in request body'}), 400
        service_info = data['service_info']
        logger.info("📥 Service description request")
        model, tokenizer = model_cache.get_service_model()
        input_text = f"generate professional description: {service_info}"
        input_ids = tokenizer.encode(input_text, return_tensors="pt").to(DEVICE)
        with torch.no_grad():
            outputs = model.generate(input_ids, max_length=512, num_beams=4)
        description = tokenizer.decode(outputs[0], skip_special_tokens=True)
        elapsed = time.time() - start
        logger.info(f"✅ Service description ({elapsed:.2f}s)")
        return jsonify({
            'service_info': service_info,
            'description': description,
            'inference_time': f'{elapsed:.3f}s',
            'model': 'service-description'
        })
    except Exception as e:
        logger.error(f"❌ Service error: {str(e)}")
        return jsonify({
            'error': 'Inference failed',
            'message': str(e)[:300]
        }), 500
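
# Example request (illustrative; host and service_info value are hypothetical):
#   curl -X POST http://localhost:7860/api/service \
#        -H "Content-Type: application/json" \
#        -d '{"service_info": "plumber, 10 years experience, emergency repairs"}'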

@app.route('/api/search', methods=['POST'])
def semantic_search():
    try:
        start = time.time()
        data = request.get_json()
        if not data or 'query' not in data or 'documents' not in data:
            return jsonify({'error': 'Missing "query" or "documents"'}), 400
        query = data['query']
        documents = data['documents']
        if not isinstance(documents, list) or not documents:
            return jsonify({'error': '"documents" must be a non-empty list'}), 400
        logger.info(f"📥 Search request: {query[:50]}")
        model, _ = model_cache.get_search_model()
        query_embedding = model.encode([query])
        doc_embeddings = model.encode(documents)
        from sklearn.metrics.pairwise import cosine_similarity
        similarities = cosine_similarity(query_embedding, doc_embeddings)[0]
        # Rank documents by similarity, highest first
        results = [
            {'document': doc, 'similarity': float(score), 'rank': i + 1}
            for i, (doc, score) in enumerate(
                sorted(zip(documents, similarities), key=lambda x: x[1], reverse=True)
            )
        ]
        elapsed = time.time() - start
        logger.info(f"✅ Search complete ({elapsed:.2f}s)")
        return jsonify({
            'query': query,
            'results': results,
            'inference_time': f'{elapsed:.3f}s',
            'model': 'semantic-search'
        })
    except Exception as e:
        logger.error(f"❌ Search error: {str(e)}")
        return jsonify({
            'error': 'Inference failed',
            'message': str(e)[:300]
        }), 500
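
# Example request (illustrative; host, query, and documents are hypothetical):
#   curl -X POST http://localhost:7860/api/search \
#        -H "Content-Type: application/json" \
#        -d '{"query": "fix my sink", "documents": ["Plumbing repairs", "House painting", "Emergency pipe fixes"]}'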

@app.route('/api/moderate', methods=['POST'])
def moderate_content():
    try:
        start = time.time()
        data = request.get_json()
        if not data or 'text' not in data:
            return jsonify({'error': 'Missing "text" in request body'}), 400
        text = data['text']
        logger.info("📥 Moderation request")
        model, tokenizer = model_cache.get_moderation_model()
        inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length=512)
        inputs = {k: v.to(DEVICE) for k, v in inputs.items()}
        import torch.nn.functional as F
        with torch.no_grad():
            outputs = model(**inputs)
            probs = F.softmax(outputs.logits, dim=-1)
            toxic_prob = probs[0][1].item()  # assumes label index 1 = toxic
        is_toxic = toxic_prob > 0.5
        elapsed = time.time() - start
        logger.info(f"✅ Moderation: {'toxic' if is_toxic else 'safe'} ({elapsed:.2f}s)")
        return jsonify({
            'text': text,
            'is_toxic': is_toxic,
            'toxicity_score': round(toxic_prob, 4),
            'status': 'toxic' if is_toxic else 'safe',
            'inference_time': f'{elapsed:.3f}s',
            'model': 'content-moderation'
        })
    except Exception as e:
        logger.error(f"❌ Moderation error: {str(e)}")
        return jsonify({
            'error': 'Inference failed',
            'message': str(e)[:300]
        }), 500
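
# Example request (illustrative; host and text value are hypothetical):
#   curl -X POST http://localhost:7860/api/moderate \
#        -H "Content-Type: application/json" \
#        -d '{"text": "Great service, highly recommended!"}'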

@app.errorhandler(404)
def not_found(error):
    return jsonify({
        'error': 'Endpoint not found',
        'available': ['/', '/health', '/api/category', '/api/chat', '/api/service', '/api/search', '/api/moderate']
    }), 404

@app.errorhandler(500)
def internal_error(error):
    return jsonify({'error': 'Internal server error'}), 500

if __name__ == '__main__':
    port = int(os.getenv('PORT', 7860))
    logger.info(f"🚀 Starting NEED AI API on port {port}")
    logger.info(f"👤 Username: {HF_USERNAME}")
    logger.info("💡 Models will load on first request (30-60s each)")
    app.run(host='0.0.0.0', port=port, debug=False)
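
# Local run sketch (assumptions: the package list below is inferred from the
# imports above; gunicorn is a common production alternative, not required here):
#   pip install flask flask-cors torch transformers sentencepiece sentence-transformers scikit-learn
#   python app.py
#   # or: gunicorn -b 0.0.0.0:7860 app:app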