yogami9 committed · Commit 36b4278 · verified · 1 Parent(s): 32a0420

Add app.py

Files changed (1):
  1. app.py (+325, -0)
app.py ADDED
@@ -0,0 +1,325 @@
+ #!/usr/bin/env python3
+ """
+ NEED AI - Production Flask API with Direct Model Loading
+ """
+
+ from flask import Flask, request, jsonify
+ from flask_cors import CORS
+ from transformers import (
+     T5ForConditionalGeneration,
+     T5Tokenizer,
+     AutoModelForSequenceClassification,
+     AutoTokenizer
+ )
+ from sentence_transformers import SentenceTransformer
+ import torch
+ import torch.nn.functional as F
+ from sklearn.metrics.pairwise import cosine_similarity
+ import logging
+ import os
+ from functools import lru_cache
+ import time
+
+ logging.basicConfig(level=logging.INFO)
+ logger = logging.getLogger(__name__)
+
+ app = Flask(__name__)
+ CORS(app)
+
+ HF_USERNAME = os.getenv("HF_USERNAME", "yogami9")
+ DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+ logger.info(f"Using device: {DEVICE}")
+
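+ # Dependency note (not spelled out in the original commit): this file assumes
+ # flask, flask-cors, transformers, sentence-transformers, torch, scikit-learn,
+ # and sentencepiece (required by T5Tokenizer) are installed.
+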
+ class ModelCache:
+     """Lazily loads each model on first use and keeps it in process memory."""
+
+     def __init__(self):
+         self.models = {}
+         self.tokenizers = {}
+         logger.info("Model cache initialized")
+
+     @lru_cache(maxsize=1)  # effectively a no-op: the dict check below already memoizes
+     def get_category_model(self):
+         if 'category' not in self.models:
+             logger.info("Loading Category model...")
+             model_name = f"{HF_USERNAME}/need-category-recommendation"
+             self.models['category'] = T5ForConditionalGeneration.from_pretrained(model_name).to(DEVICE)
+             self.tokenizers['category'] = T5Tokenizer.from_pretrained(model_name)
+             logger.info("✅ Category model loaded")
+         return self.models['category'], self.tokenizers['category']
+
+     @lru_cache(maxsize=1)
+     def get_chat_model(self):
+         if 'chat' not in self.models:
+             logger.info("Loading Chat model...")
+             model_name = f"{HF_USERNAME}/need-chat-support"
+             self.models['chat'] = T5ForConditionalGeneration.from_pretrained(model_name).to(DEVICE)
+             self.tokenizers['chat'] = T5Tokenizer.from_pretrained(model_name)
+             logger.info("✅ Chat model loaded")
+         return self.models['chat'], self.tokenizers['chat']
+
+     @lru_cache(maxsize=1)
+     def get_service_model(self):
+         if 'service' not in self.models:
+             logger.info("Loading Service model...")
+             model_name = f"{HF_USERNAME}/need-service-description"
+             self.models['service'] = T5ForConditionalGeneration.from_pretrained(model_name).to(DEVICE)
+             self.tokenizers['service'] = T5Tokenizer.from_pretrained(model_name)
+             logger.info("✅ Service model loaded")
+         return self.models['service'], self.tokenizers['service']
+
+     @lru_cache(maxsize=1)
+     def get_search_model(self):
+         if 'search' not in self.models:
+             logger.info("Loading Search model...")
+             model_name = f"{HF_USERNAME}/need-semantic-search"
+             self.models['search'] = SentenceTransformer(model_name)
+             logger.info("✅ Search model loaded")
+         return self.models['search']
+
+     @lru_cache(maxsize=1)
+     def get_moderation_model(self):
+         if 'moderation' not in self.models:
+             logger.info("Loading Moderation model...")
+             model_name = f"{HF_USERNAME}/need-content-moderation"
+             self.models['moderation'] = AutoModelForSequenceClassification.from_pretrained(model_name).to(DEVICE)
+             self.tokenizers['moderation'] = AutoTokenizer.from_pretrained(model_name)
+             logger.info("✅ Moderation model loaded")
+         return self.models['moderation'], self.tokenizers['moderation']
+
+ model_cache = ModelCache()
+
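+ # Optional warm-up (an assumption, not part of the original design): the
+ # getters above only load on first use, so the first request to each endpoint
+ # pays the full download/initialization cost. A deployment could preload:
+ #
+ #   if os.getenv("PRELOAD_MODELS", "0") == "1":   # hypothetical env flag
+ #       model_cache.get_category_model()
+ #       model_cache.get_chat_model()
+ #       model_cache.get_service_model()
+ #       model_cache.get_search_model()
+ #       model_cache.get_moderation_model()
+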
+ @app.route('/', methods=['GET'])
+ def home():
+     return jsonify({
+         'name': 'NEED AI API',
+         'version': '1.0.0',
+         'status': 'running',
+         'endpoints': {
+             'health': '/health',
+             'category': '/api/category',
+             'chat': '/api/chat',
+             'service': '/api/service',
+             'search': '/api/search',
+             'moderate': '/api/moderate',
+             'batch': '/api/batch'
+         },
+         'documentation': 'https://github.com/Need-Service-App/need-ai-model'
+     })
+
+ @app.route('/health', methods=['GET'])
+ def health():
+     return jsonify({
+         'status': 'healthy',
+         'device': str(DEVICE),
+         'models_loaded': len(model_cache.models),
+         'gpu_available': torch.cuda.is_available()
+     })
+
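+ # Example (hypothetical host; assumes the server runs locally on the default
+ # port 7860): probing the service from Python with the requests library:
+ #
+ #   import requests
+ #   print(requests.get("http://localhost:7860/health").json())
+ #   # -> {"status": "healthy", "device": "cpu", "models_loaded": 0, ...}
+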
+ @app.route('/api/category', methods=['POST'])
+ def predict_category():
+     try:
+         start_time = time.time()
+         data = request.get_json()
+         if not data or 'query' not in data:
+             return jsonify({'error': 'Missing "query" in request body'}), 400
+
+         query = data['query']
+         model, tokenizer = model_cache.get_category_model()
+
+         # The "categorize:" prefix presumably matches the prompt format used
+         # when the T5 model was fine-tuned.
+         input_text = f"categorize: {query}"
+         input_ids = tokenizer.encode(input_text, return_tensors="pt").to(DEVICE)
+
+         with torch.no_grad():
+             outputs = model.generate(input_ids, max_length=32, num_beams=4, early_stopping=True)
+
+         category = tokenizer.decode(outputs[0], skip_special_tokens=True)
+         inference_time = time.time() - start_time
+
+         return jsonify({
+             'query': query,
+             'category': category,
+             'inference_time': f"{inference_time:.3f}s"
+         })
+     except Exception as e:
+         logger.error(f"Error in predict_category: {str(e)}")
+         return jsonify({'error': str(e)}), 500
+
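+ # Example (hypothetical input; the model must be reachable on the Hub):
+ #
+ #   import requests
+ #   r = requests.post("http://localhost:7860/api/category",
+ #                     json={"query": "fix a leaking kitchen sink"})
+ #   print(r.json())  # e.g. {"query": "...", "category": "...", "inference_time": "0.412s"}
+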
+ @app.route('/api/chat', methods=['POST'])
+ def answer_question():
+     try:
+         start_time = time.time()
+         data = request.get_json()
+         if not data or 'question' not in data:
+             return jsonify({'error': 'Missing "question" in request body'}), 400
+
+         question = data['question']
+         model, tokenizer = model_cache.get_chat_model()
+
+         input_text = f"answer question: {question}"
+         input_ids = tokenizer.encode(input_text, return_tensors="pt").to(DEVICE)
+
+         with torch.no_grad():
+             outputs = model.generate(input_ids, max_length=256, num_beams=4, early_stopping=True)
+
+         answer = tokenizer.decode(outputs[0], skip_special_tokens=True)
+         inference_time = time.time() - start_time
+
+         return jsonify({
+             'question': question,
+             'answer': answer,
+             'inference_time': f"{inference_time:.3f}s"
+         })
+     except Exception as e:
+         logger.error(f"Error in answer_question: {str(e)}")
+         return jsonify({'error': str(e)}), 500
+
+ @app.route('/api/service', methods=['POST'])
+ def generate_description():
+     try:
+         start_time = time.time()
+         data = request.get_json()
+         if not data or 'service_info' not in data:
+             return jsonify({'error': 'Missing "service_info" in request body'}), 400
+
+         service_info = data['service_info']
+         model, tokenizer = model_cache.get_service_model()
+
+         input_text = f"generate professional description: {service_info}"
+         input_ids = tokenizer.encode(input_text, return_tensors="pt").to(DEVICE)
+
+         with torch.no_grad():
+             outputs = model.generate(input_ids, max_length=512, num_beams=4, early_stopping=True)
+
+         description = tokenizer.decode(outputs[0], skip_special_tokens=True)
+         inference_time = time.time() - start_time
+
+         return jsonify({
+             'service_info': service_info,
+             'description': description,
+             'inference_time': f"{inference_time:.3f}s"
+         })
+     except Exception as e:
+         logger.error(f"Error in generate_description: {str(e)}")
+         return jsonify({'error': str(e)}), 500
+
+ @app.route('/api/search', methods=['POST'])
+ def semantic_search():
+     try:
+         start_time = time.time()
+         data = request.get_json()
+         if not data or 'query' not in data or 'documents' not in data:
+             return jsonify({'error': 'Missing "query" or "documents" in request body'}), 400
+
+         query = data['query']
+         documents = data['documents']
+
+         # An empty list would crash cosine_similarity below, so reject it here.
+         if not isinstance(documents, list) or not documents:
+             return jsonify({'error': '"documents" must be a non-empty list'}), 400
+
+         model = model_cache.get_search_model()
+
+         query_embedding = model.encode([query])   # shape (1, dim)
+         doc_embeddings = model.encode(documents)  # shape (n_docs, dim)
+
+         similarities = cosine_similarity(query_embedding, doc_embeddings)[0]
+
+         # Sort documents by similarity, highest first, with 1-based ranks.
+         results = [
+             {
+                 'document': doc,
+                 'similarity': float(score),
+                 'rank': i + 1
+             }
+             for i, (doc, score) in enumerate(
+                 sorted(zip(documents, similarities), key=lambda x: x[1], reverse=True)
+             )
+         ]
+
+         inference_time = time.time() - start_time
+
+         return jsonify({
+             'query': query,
+             'results': results,
+             'inference_time': f"{inference_time:.3f}s"
+         })
+     except Exception as e:
+         logger.error(f"Error in semantic_search: {str(e)}")
+         return jsonify({'error': str(e)}), 500
+
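+ # Example (hypothetical data): ranking a handful of service listings:
+ #
+ #   import requests
+ #   r = requests.post("http://localhost:7860/api/search", json={
+ #       "query": "emergency plumber",
+ #       "documents": ["24/7 plumbing repairs", "wedding photography", "drain cleaning"],
+ #   })
+ #   for hit in r.json()["results"]:
+ #       print(hit["rank"], round(hit["similarity"], 3), hit["document"])
+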
+ @app.route('/api/moderate', methods=['POST'])
+ def moderate_content():
+     try:
+         start_time = time.time()
+         data = request.get_json()
+         if not data or 'text' not in data:
+             return jsonify({'error': 'Missing "text" in request body'}), 400
+
+         text = data['text']
+         model, tokenizer = model_cache.get_moderation_model()
+
+         inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length=512)
+         inputs = {k: v.to(DEVICE) for k, v in inputs.items()}
+
+         with torch.no_grad():
+             outputs = model(**inputs)
+             probabilities = F.softmax(outputs.logits, dim=-1)
+             # Assumes a binary classifier whose label index 1 is the "toxic" class.
+             toxic_prob = probabilities[0][1].item()
+
+         is_toxic = toxic_prob > 0.5
+         inference_time = time.time() - start_time
+
+         return jsonify({
+             'text': text,
+             'is_toxic': is_toxic,
+             'toxicity_score': round(toxic_prob, 4),
+             'status': 'toxic' if is_toxic else 'safe',
+             'inference_time': f"{inference_time:.3f}s"
+         })
+     except Exception as e:
+         logger.error(f"Error in moderate_content: {str(e)}")
+         return jsonify({'error': str(e)}), 500
+
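+ # Example (hypothetical input):
+ #
+ #   import requests
+ #   r = requests.post("http://localhost:7860/api/moderate",
+ #                     json={"text": "You are wonderful"})
+ #   body = r.json()
+ #   print(body["status"], body["toxicity_score"])
+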
+ @app.route('/api/batch', methods=['POST'])
+ def batch_process():
+     try:
+         data = request.get_json()
+         if not data or 'requests' not in data:
+             return jsonify({'error': 'Missing "requests" in request body'}), 400
+
+         results = []
+         for req in data['requests']:
+             req_type = req.get('type')
+
+             if req_type == 'category':
+                 model, tokenizer = model_cache.get_category_model()
+                 input_text = f"categorize: {req['query']}"
+                 input_ids = tokenizer.encode(input_text, return_tensors="pt").to(DEVICE)
+                 with torch.no_grad():
+                     outputs = model.generate(input_ids, max_length=32)
+                 result = tokenizer.decode(outputs[0], skip_special_tokens=True)
+                 results.append({'type': 'category', 'result': result})
+
+             elif req_type == 'chat':
+                 model, tokenizer = model_cache.get_chat_model()
+                 input_text = f"answer question: {req['question']}"
+                 input_ids = tokenizer.encode(input_text, return_tensors="pt").to(DEVICE)
+                 with torch.no_grad():
+                     outputs = model.generate(input_ids, max_length=256)
+                 result = tokenizer.decode(outputs[0], skip_special_tokens=True)
+                 results.append({'type': 'chat', 'result': result})
+
+             else:
+                 # Report unsupported types rather than skipping them silently,
+                 # so callers can match results back to requests one-to-one.
+                 results.append({'type': req_type, 'error': 'Unsupported request type'})
+
+         return jsonify({'results': results})
+     except Exception as e:
+         logger.error(f"Error in batch_process: {str(e)}")
+         return jsonify({'error': str(e)}), 500
+
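+ # Example (hypothetical payload) mixing the two supported request types:
+ #
+ #   import requests
+ #   r = requests.post("http://localhost:7860/api/batch", json={"requests": [
+ #       {"type": "category", "query": "mount a TV on the wall"},
+ #       {"type": "chat", "question": "How do I book a provider?"},
+ #   ]})
+ #   print(r.json()["results"])
+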
+ @app.errorhandler(404)
+ def not_found(error):
+     return jsonify({'error': 'Endpoint not found'}), 404
+
+ @app.errorhandler(500)
+ def internal_error(error):
+     return jsonify({'error': 'Internal server error'}), 500
+
+ if __name__ == '__main__':
+     port = int(os.getenv('PORT', 7860))
+     logger.info(f"Starting server on port {port}...")
+     app.run(host='0.0.0.0', port=port, debug=False)
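+
+ # Deployment note (an assumption about typical usage, not stated in this
+ # commit): for production traffic a WSGI server is usually preferred over
+ # Flask's built-in server, e.g.:
+ #
+ #   gunicorn --bind 0.0.0.0:7860 --workers 1 app:app
+ #
+ # A single worker keeps only one copy of each model in memory.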