#!/usr/bin/env python3
"""
NEED AI - Production Flask API (FIXED VERSION)
Models are loaded lazily at request time, not at import time.
"""
from flask import Flask, request, jsonify
from flask_cors import CORS
import torch
import logging
import os
import time

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

app = Flask(__name__)
CORS(app)

# Respect an HF_USERNAME already set in the environment; otherwise default
# to 'yogami9' and export it so the process env and this constant agree.
HF_USERNAME = os.environ.setdefault('HF_USERNAME', 'yogami9')

DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
logger.info(f"🖥️ Device: {DEVICE}")
logger.info(f"👤 HF Username: {HF_USERNAME}")


class ModelCache:
    """Lazy-loading model cache."""

    def __init__(self):
        self.models = {}
        self.tokenizers = {}
        self.load_attempts = {}
        logger.info("📦 Model cache initialized")

    def _load_model(self, model_key, model_id, loader_func):
        """Generic model loader with error handling."""
        if model_key in self.models:
            return self.models[model_key], self.tokenizers.get(model_key)

        # Track load attempts to avoid retrying a permanently broken model forever
        if self.load_attempts.get(model_key, 0) > 3:
            raise RuntimeError(f"Max load attempts exceeded for {model_key}")
        self.load_attempts[model_key] = self.load_attempts.get(model_key, 0) + 1

        try:
            logger.info(f"📥 Loading {model_key} from {model_id}...")
            result = loader_func(model_id)
            logger.info(f"✅ {model_key} loaded successfully")
            return result
        except Exception as e:
            logger.error(f"❌ Failed to load {model_key}: {str(e)[:200]}")
            raise

    def _t5_loader(self, model_key):
        """Build a loader that fetches a T5 model + tokenizer and caches both."""
        def loader(model_id):
            from transformers import T5ForConditionalGeneration, T5Tokenizer
            model = T5ForConditionalGeneration.from_pretrained(model_id).to(DEVICE)
            tokenizer = T5Tokenizer.from_pretrained(model_id)
            self.models[model_key] = model
            self.tokenizers[model_key] = tokenizer
            return model, tokenizer
        return loader

    def get_category_model(self):
        return self._load_model(
            'category', f'{HF_USERNAME}/need-category-recommendation',
            self._t5_loader('category'))

    def get_chat_model(self):
        return self._load_model(
            'chat', f'{HF_USERNAME}/need-chat-support',
            self._t5_loader('chat'))

    def get_service_model(self):
        return self._load_model(
            'service', f'{HF_USERNAME}/need-service-description',
            self._t5_loader('service'))

    def get_search_model(self):
        def loader(model_id):
            from sentence_transformers import SentenceTransformer
            model = SentenceTransformer(model_id)
            self.models['search'] = model
            return model, None  # sentence-transformers bundles its own tokenizer
        return self._load_model(
            'search', f'{HF_USERNAME}/need-semantic-search', loader)

    def get_moderation_model(self):
        def loader(model_id):
            from transformers import AutoModelForSequenceClassification, AutoTokenizer
            model = AutoModelForSequenceClassification.from_pretrained(model_id).to(DEVICE)
            tokenizer = AutoTokenizer.from_pretrained(model_id)
            self.models['moderation'] = model
            self.tokenizers['moderation'] = tokenizer
            return model, tokenizer
        return self._load_model(
            'moderation', f'{HF_USERNAME}/need-content-moderation', loader)


# Initialize cache
model_cache = ModelCache()


@app.route('/', methods=['GET'])
def home():
    return jsonify({
        'name': 'NEED AI API',
        'version': '2.0.1',
        'status': 'running',
        'username': HF_USERNAME,
        'models_loaded': len(model_cache.models),
        'endpoints': {
            'health': '/health',
            'category': '/api/category',
            'chat': '/api/chat',
            'service': '/api/service',
            'search': '/api/search',
            'moderate': '/api/moderate'
        },
        'documentation': 'https://github.com/Need-Service-App/need-ai-model',
        'note': 'First request per model takes 30-60 seconds (model download)'
    })


@app.route('/health', methods=['GET'])
def health():
    return jsonify({
        'status': 'healthy',
        'device': str(DEVICE),
        'gpu_available': torch.cuda.is_available(),
        'models_loaded': len(model_cache.models),
        'models_cached': list(model_cache.models.keys()),
        'username': HF_USERNAME,
        'note': 'Models load on first use'
    })


@app.route('/api/category', methods=['POST'])
def predict_category():
    try:
        start = time.time()
        data = request.get_json()
        if not data or 'query' not in data:
            return jsonify({'error': 'Missing "query" in request body'}), 400

        query = data['query']
        logger.info(f"📥 Category request: {query[:50]}")

        # Load model on demand
        model, tokenizer = model_cache.get_category_model()

        input_text = f"categorize: {query}"
        # Truncate so arbitrarily long user input cannot blow up generation
        input_ids = tokenizer.encode(
            input_text, return_tensors="pt", truncation=True, max_length=512
        ).to(DEVICE)
        with torch.no_grad():
            outputs = model.generate(input_ids, max_length=32, num_beams=4)
        category = tokenizer.decode(outputs[0], skip_special_tokens=True)

        elapsed = time.time() - start
        logger.info(f"✅ Category: {category} ({elapsed:.2f}s)")

        return jsonify({
            'query': query,
            'category': category,
            'inference_time': f'{elapsed:.3f}s',
            'model': 'category-recommendation'
        })
    except Exception as e:
        logger.error(f"❌ Category error: {str(e)}")
        return jsonify({
            'error': 'Inference failed',
            'message': str(e)[:300],
            'suggestion': 'Check logs. Model may be loading (wait 30s) or not accessible.'
        }), 500


@app.route('/api/chat', methods=['POST'])
def answer_question():
    try:
        start = time.time()
        data = request.get_json()
        if not data or 'question' not in data:
            return jsonify({'error': 'Missing "question" in request body'}), 400

        question = data['question']
        logger.info(f"📥 Chat request: {question[:50]}")

        model, tokenizer = model_cache.get_chat_model()

        input_text = f"answer question: {question}"
        input_ids = tokenizer.encode(
            input_text, return_tensors="pt", truncation=True, max_length=512
        ).to(DEVICE)
        with torch.no_grad():
            outputs = model.generate(input_ids, max_length=256, num_beams=4)
        answer = tokenizer.decode(outputs[0], skip_special_tokens=True)

        elapsed = time.time() - start
        logger.info(f"✅ Chat answer ({elapsed:.2f}s)")

        return jsonify({
            'question': question,
            'answer': answer,
            'inference_time': f'{elapsed:.3f}s',
            'model': 'chat-support'
        })
    except Exception as e:
        logger.error(f"❌ Chat error: {str(e)}")
        return jsonify({'error': 'Inference failed', 'message': str(e)[:300]}), 500


@app.route('/api/service', methods=['POST'])
def generate_description():
    try:
        start = time.time()
        data = request.get_json()
        if not data or 'service_info' not in data:
            return jsonify({'error': 'Missing "service_info" in request body'}), 400

        service_info = data['service_info']
        logger.info("📥 Service description request")

        model, tokenizer = model_cache.get_service_model()

        input_text = f"generate professional description: {service_info}"
        input_ids = tokenizer.encode(
            input_text, return_tensors="pt", truncation=True, max_length=512
        ).to(DEVICE)
        with torch.no_grad():
            outputs = model.generate(input_ids, max_length=512, num_beams=4)
        description = tokenizer.decode(outputs[0], skip_special_tokens=True)

        elapsed = time.time() - start
        logger.info(f"✅ Service description ({elapsed:.2f}s)")

        return jsonify({
            'service_info': service_info,
            'description': description,
            'inference_time': f'{elapsed:.3f}s',
            'model': 'service-description'
        })
    except Exception as e:
        logger.error(f"❌ Service error: {str(e)}")
        return jsonify({'error': 'Inference failed', 'message': str(e)[:300]}), 500


@app.route('/api/search', methods=['POST'])
def semantic_search():
    try:
        start = time.time()
        data = request.get_json()
        if not data or 'query' not in data or 'documents' not in data:
            return jsonify({'error': 'Missing "query" or "documents"'}), 400

        query = data['query']
        documents = data['documents']
        if not isinstance(documents, list) or not documents:
            return jsonify({'error': '"documents" must be a non-empty list'}), 400

        logger.info(f"📥 Search request: {query[:50]}")

        model, _ = model_cache.get_search_model()

        query_embedding = model.encode([query])
        doc_embeddings = model.encode(documents)

        # Lazy import: sklearn is only needed by this endpoint
        from sklearn.metrics.pairwise import cosine_similarity
        similarities = cosine_similarity(query_embedding, doc_embeddings)[0]

        # Rank documents by descending similarity
        results = [
            {'document': doc, 'similarity': float(score), 'rank': i + 1}
            for i, (doc, score) in enumerate(
                sorted(zip(documents, similarities), key=lambda x: x[1], reverse=True)
            )
        ]

        elapsed = time.time() - start
        logger.info(f"✅ Search complete ({elapsed:.2f}s)")

        return jsonify({
            'query': query,
            'results': results,
            'inference_time': f'{elapsed:.3f}s',
            'model': 'semantic-search'
        })
    except Exception as e:
        logger.error(f"❌ Search error: {str(e)}")
        return jsonify({'error': 'Inference failed', 'message': str(e)[:300]}), 500


@app.route('/api/moderate', methods=['POST'])
def moderate_content():
    try:
        start = time.time()
        data = request.get_json()
        if not data or 'text' not in data:
            return jsonify({'error': 'Missing "text" in request body'}), 400

        text = data['text']
        logger.info("📥 Moderation request")

        model, tokenizer = model_cache.get_moderation_model()
        inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length=512)
        inputs = {k: v.to(DEVICE) for k, v in inputs.items()}

        import torch.nn.functional as F
        with torch.no_grad():
            outputs = model(**inputs)
            probs = F.softmax(outputs.logits, dim=-1)
            # Assumes label index 1 is the "toxic" class of the classifier
            toxic_prob = probs[0][1].item()

        is_toxic = toxic_prob > 0.5

        elapsed = time.time() - start
        logger.info(f"✅ Moderation: {'toxic' if is_toxic else 'safe'} ({elapsed:.2f}s)")

        return jsonify({
            'text': text,
            'is_toxic': is_toxic,
            'toxicity_score': round(toxic_prob, 4),
            'status': 'toxic' if is_toxic else 'safe',
            'inference_time': f'{elapsed:.3f}s',
            'model': 'content-moderation'
        })
    except Exception as e:
        logger.error(f"❌ Moderation error: {str(e)}")
        return jsonify({'error': 'Inference failed', 'message': str(e)[:300]}), 500


@app.errorhandler(404)
def not_found(error):
    return jsonify({
        'error': 'Endpoint not found',
        'available': ['/', '/health', '/api/category', '/api/chat',
                      '/api/service', '/api/search', '/api/moderate']
    }), 404


@app.errorhandler(500)
def internal_error(error):
    return jsonify({'error': 'Internal server error'}), 500


if __name__ == '__main__':
    port = int(os.getenv('PORT', 7860))
    logger.info(f"🚀 Starting NEED AI API on port {port}")
    logger.info(f"👤 Username: {HF_USERNAME}")
    logger.info("💡 Models will load on first request (30-60s each)")
    app.run(host='0.0.0.0', port=port, debug=False)
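
# --- Example client usage (sketch, not part of the server) ---
# A minimal smoke test against a locally running instance. This assumes the
# defaults above (port 7860), that the `requests` package is installed, and
# that the example payload values are placeholders; adjust the base URL for
# your deployment. The first call per model may take 30-60s while weights
# download.
#
#   import requests
#
#   BASE = 'http://localhost:7860'
#   print(requests.post(f'{BASE}/api/category',
#                       json={'query': 'fix a leaking kitchen tap'},
#                       timeout=120).json())
#   print(requests.post(f'{BASE}/api/search',
#                       json={'query': 'plumber near me',
#                             'documents': ['Emergency plumbing repairs',
#                                           'Wedding photography packages']},
#                       timeout=120).json())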