yogami9 committed
Commit 2540970 · verified · 1 Parent(s): ee5993a

Deploy NEED AI API - app.py

Files changed (1):
  1. app.py +51 -231
app.py CHANGED
@@ -1,23 +1,13 @@
 #!/usr/bin/env python3
-"""
-NEED AI - Production Flask API with Direct Model Loading
-"""
-
 from flask import Flask, request, jsonify
 from flask_cors import CORS
-from transformers import (
-    T5ForConditionalGeneration,
-    T5Tokenizer,
-    AutoModelForSequenceClassification,
-    AutoTokenizer
-)
+from transformers import T5ForConditionalGeneration, T5Tokenizer, AutoModelForSequenceClassification, AutoTokenizer
 from sentence_transformers import SentenceTransformer
 import torch
 import torch.nn.functional as F
 from sklearn.metrics.pairwise import cosine_similarity
 import logging
 import os
-from functools import lru_cache
 import time
 
 logging.basicConfig(level=logging.INFO)
@@ -26,300 +16,130 @@ logger = logging.getLogger(__name__)
 app = Flask(__name__)
 CORS(app)
 
-HF_USERNAME = os.getenv("HF_USERNAME", "yogami9")
+HF_USERNAME = "yogami9"
 DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
-logger.info(f"Using device: {DEVICE}")
 
 class ModelCache:
     def __init__(self):
         self.models = {}
         self.tokenizers = {}
-        logger.info("Model cache initialized")
-
-    @lru_cache(maxsize=1)
+
     def get_category_model(self):
         if 'category' not in self.models:
             logger.info("Loading Category model...")
-            model_name = f"{HF_USERNAME}/need-category-recommendation"
-            self.models['category'] = T5ForConditionalGeneration.from_pretrained(model_name).to(DEVICE)
-            self.tokenizers['category'] = T5Tokenizer.from_pretrained(model_name)
-            logger.info("✅ Category model loaded")
+            self.models['category'] = T5ForConditionalGeneration.from_pretrained(f"{HF_USERNAME}/need-category-recommendation").to(DEVICE)
+            self.tokenizers['category'] = T5Tokenizer.from_pretrained(f"{HF_USERNAME}/need-category-recommendation")
         return self.models['category'], self.tokenizers['category']
 
-    @lru_cache(maxsize=1)
     def get_chat_model(self):
         if 'chat' not in self.models:
             logger.info("Loading Chat model...")
-            model_name = f"{HF_USERNAME}/need-chat-support"
-            self.models['chat'] = T5ForConditionalGeneration.from_pretrained(model_name).to(DEVICE)
-            self.tokenizers['chat'] = T5Tokenizer.from_pretrained(model_name)
-            logger.info("✅ Chat model loaded")
+            self.models['chat'] = T5ForConditionalGeneration.from_pretrained(f"{HF_USERNAME}/need-chat-support").to(DEVICE)
+            self.tokenizers['chat'] = T5Tokenizer.from_pretrained(f"{HF_USERNAME}/need-chat-support")
         return self.models['chat'], self.tokenizers['chat']
 
-    @lru_cache(maxsize=1)
     def get_service_model(self):
         if 'service' not in self.models:
             logger.info("Loading Service model...")
-            model_name = f"{HF_USERNAME}/need-service-description"
-            self.models['service'] = T5ForConditionalGeneration.from_pretrained(model_name).to(DEVICE)
-            self.tokenizers['service'] = T5Tokenizer.from_pretrained(model_name)
-            logger.info("✅ Service model loaded")
+            self.models['service'] = T5ForConditionalGeneration.from_pretrained(f"{HF_USERNAME}/need-service-description").to(DEVICE)
+            self.tokenizers['service'] = T5Tokenizer.from_pretrained(f"{HF_USERNAME}/need-service-description")
         return self.models['service'], self.tokenizers['service']
 
-    @lru_cache(maxsize=1)
     def get_search_model(self):
         if 'search' not in self.models:
             logger.info("Loading Search model...")
-            model_name = f"{HF_USERNAME}/need-semantic-search"
-            self.models['search'] = SentenceTransformer(model_name)
-            logger.info("✅ Search model loaded")
+            self.models['search'] = SentenceTransformer(f"{HF_USERNAME}/need-semantic-search")
        return self.models['search']
 
-    @lru_cache(maxsize=1)
     def get_moderation_model(self):
         if 'moderation' not in self.models:
             logger.info("Loading Moderation model...")
-            model_name = f"{HF_USERNAME}/need-content-moderation"
-            self.models['moderation'] = AutoModelForSequenceClassification.from_pretrained(model_name).to(DEVICE)
-            self.tokenizers['moderation'] = AutoTokenizer.from_pretrained(model_name)
-            logger.info("✅ Moderation model loaded")
+            self.models['moderation'] = AutoModelForSequenceClassification.from_pretrained(f"{HF_USERNAME}/need-content-moderation").to(DEVICE)
+            self.tokenizers['moderation'] = AutoTokenizer.from_pretrained(f"{HF_USERNAME}/need-content-moderation")
         return self.models['moderation'], self.tokenizers['moderation']
 
-model_cache = ModelCache()
+cache = ModelCache()
 
-@app.route('/', methods=['GET'])
+@app.route('/')
 def home():
-    return jsonify({
-        'name': 'NEED AI API',
-        'version': '1.0.0',
-        'status': 'running',
-        'endpoints': {
-            'health': '/health',
-            'category': '/api/category',
-            'chat': '/api/chat',
-            'service': '/api/service',
-            'search': '/api/search',
-            'moderate': '/api/moderate',
-            'batch': '/api/batch'
-        },
-        'documentation': 'https://github.com/Need-Service-App/need-ai-model'
-    })
+    return jsonify({'name': 'NEED AI API', 'status': 'running', 'models': 5})
 
-@app.route('/health', methods=['GET'])
+@app.route('/health')
 def health():
-    return jsonify({
-        'status': 'healthy',
-        'device': str(DEVICE),
-        'models_loaded': len(model_cache.models),
-        'gpu_available': torch.cuda.is_available()
-    })
+    return jsonify({'status': 'healthy', 'models_loaded': len(cache.models)})
 
 @app.route('/api/category', methods=['POST'])
 def predict_category():
     try:
-        start_time = time.time()
-        data = request.get_json()
-        if not data or 'query' not in data:
-            return jsonify({'error': 'Missing "query" in request body'}), 400
-
-        query = data['query']
-        model, tokenizer = model_cache.get_category_model()
-
-        input_text = f"categorize: {query}"
-        input_ids = tokenizer.encode(input_text, return_tensors="pt").to(DEVICE)
-
+        data = request.json
+        query = data.get('query', '')
+        model, tokenizer = cache.get_category_model()
+        input_ids = tokenizer.encode(f"categorize: {query}", return_tensors="pt").to(DEVICE)
         with torch.no_grad():
-            outputs = model.generate(input_ids, max_length=32, num_beams=4, early_stopping=True)
-
+            outputs = model.generate(input_ids, max_length=32)
         category = tokenizer.decode(outputs[0], skip_special_tokens=True)
-        inference_time = time.time() - start_time
-
-        return jsonify({
-            'query': query,
-            'category': category,
-            'inference_time': f"{inference_time:.3f}s"
-        })
+        return jsonify({'query': query, 'category': category})
     except Exception as e:
-        logger.error(f"Error in predict_category: {str(e)}")
         return jsonify({'error': str(e)}), 500
 
 @app.route('/api/chat', methods=['POST'])
 def answer_question():
     try:
-        start_time = time.time()
-        data = request.get_json()
-        if not data or 'question' not in data:
-            return jsonify({'error': 'Missing "question" in request body'}), 400
-
-        question = data['question']
-        model, tokenizer = model_cache.get_chat_model()
-
-        input_text = f"answer question: {question}"
-        input_ids = tokenizer.encode(input_text, return_tensors="pt").to(DEVICE)
-
+        data = request.json
+        question = data.get('question', '')
+        model, tokenizer = cache.get_chat_model()
+        input_ids = tokenizer.encode(f"answer question: {question}", return_tensors="pt").to(DEVICE)
         with torch.no_grad():
-            outputs = model.generate(input_ids, max_length=256, num_beams=4, early_stopping=True)
-
+            outputs = model.generate(input_ids, max_length=256)
         answer = tokenizer.decode(outputs[0], skip_special_tokens=True)
-        inference_time = time.time() - start_time
-
-        return jsonify({
-            'question': question,
-            'answer': answer,
-            'inference_time': f"{inference_time:.3f}s"
-        })
+        return jsonify({'question': question, 'answer': answer})
     except Exception as e:
-        logger.error(f"Error in answer_question: {str(e)}")
         return jsonify({'error': str(e)}), 500
 
 @app.route('/api/service', methods=['POST'])
 def generate_description():
     try:
-        start_time = time.time()
-        data = request.get_json()
-        if not data or 'service_info' not in data:
-            return jsonify({'error': 'Missing "service_info" in request body'}), 400
-
-        service_info = data['service_info']
-        model, tokenizer = model_cache.get_service_model()
-
-        input_text = f"generate professional description: {service_info}"
-        input_ids = tokenizer.encode(input_text, return_tensors="pt").to(DEVICE)
-
+        data = request.json
+        service_info = data.get('service_info', '')
+        model, tokenizer = cache.get_service_model()
+        input_ids = tokenizer.encode(f"generate professional description: {service_info}", return_tensors="pt").to(DEVICE)
         with torch.no_grad():
-            outputs = model.generate(input_ids, max_length=512, num_beams=4, early_stopping=True)
-
+            outputs = model.generate(input_ids, max_length=512)
         description = tokenizer.decode(outputs[0], skip_special_tokens=True)
-        inference_time = time.time() - start_time
-
-        return jsonify({
-            'service_info': service_info,
-            'description': description,
-            'inference_time': f"{inference_time:.3f}s"
-        })
+        return jsonify({'service_info': service_info, 'description': description})
     except Exception as e:
-        logger.error(f"Error in generate_description: {str(e)}")
         return jsonify({'error': str(e)}), 500
 
 @app.route('/api/search', methods=['POST'])
 def semantic_search():
     try:
-        start_time = time.time()
-        data = request.get_json()
-        if not data or 'query' not in data or 'documents' not in data:
-            return jsonify({'error': 'Missing "query" or "documents" in request body'}), 400
-
-        query = data['query']
-        documents = data['documents']
-
-        if not isinstance(documents, list):
-            return jsonify({'error': '"documents" must be a list'}), 400
-
-        model = model_cache.get_search_model()
-
-        query_embedding = model.encode([query])
-        doc_embeddings = model.encode(documents)
-
-        similarities = cosine_similarity(query_embedding, doc_embeddings)[0]
-
-        results = [
-            {
-                'document': doc,
-                'similarity': float(score),
-                'rank': i + 1
-            }
-            for i, (doc, score) in enumerate(
-                sorted(zip(documents, similarities), key=lambda x: x[1], reverse=True)
-            )
-        ]
-
-        inference_time = time.time() - start_time
-
-        return jsonify({
-            'query': query,
-            'results': results,
-            'inference_time': f"{inference_time:.3f}s"
-        })
+        data = request.json
+        query = data.get('query', '')
+        documents = data.get('documents', [])
+        model = cache.get_search_model()
+        query_emb = model.encode([query])
+        doc_embs = model.encode(documents)
+        sims = cosine_similarity(query_emb, doc_embs)[0]
+        results = [{'document': d, 'similarity': float(s), 'rank': i+1} for i, (d, s) in enumerate(sorted(zip(documents, sims), key=lambda x: x[1], reverse=True))]
+        return jsonify({'query': query, 'results': results})
     except Exception as e:
-        logger.error(f"Error in semantic_search: {str(e)}")
         return jsonify({'error': str(e)}), 500
 
 @app.route('/api/moderate', methods=['POST'])
 def moderate_content():
     try:
-        start_time = time.time()
-        data = request.get_json()
-        if not data or 'text' not in data:
-            return jsonify({'error': 'Missing "text" in request body'}), 400
-
-        text = data['text']
-        model, tokenizer = model_cache.get_moderation_model()
-
+        data = request.json
+        text = data.get('text', '')
+        model, tokenizer = cache.get_moderation_model()
         inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length=512)
         inputs = {k: v.to(DEVICE) for k, v in inputs.items()}
-
         with torch.no_grad():
             outputs = model(**inputs)
-        probabilities = F.softmax(outputs.logits, dim=-1)
-        toxic_prob = probabilities[0][1].item()
-
-        is_toxic = toxic_prob > 0.5
-        inference_time = time.time() - start_time
-
-        return jsonify({
-            'text': text,
-            'is_toxic': is_toxic,
-            'toxicity_score': round(toxic_prob, 4),
-            'status': 'toxic' if is_toxic else 'safe',
-            'inference_time': f"{inference_time:.3f}s"
-        })
-    except Exception as e:
-        logger.error(f"Error in moderate_content: {str(e)}")
-        return jsonify({'error': str(e)}), 500
-
-@app.route('/api/batch', methods=['POST'])
-def batch_process():
-    try:
-        data = request.get_json()
-        if not data or 'requests' not in data:
-            return jsonify({'error': 'Missing "requests" in request body'}), 400
-
-        results = []
-        for req in data['requests']:
-            req_type = req.get('type')
-
-            if req_type == 'category':
-                model, tokenizer = model_cache.get_category_model()
-                input_text = f"categorize: {req['query']}"
-                input_ids = tokenizer.encode(input_text, return_tensors="pt").to(DEVICE)
-                with torch.no_grad():
-                    outputs = model.generate(input_ids, max_length=32)
-                result = tokenizer.decode(outputs[0], skip_special_tokens=True)
-                results.append({'type': 'category', 'result': result})
-
-            elif req_type == 'chat':
-                model, tokenizer = model_cache.get_chat_model()
-                input_text = f"answer question: {req['question']}"
-                input_ids = tokenizer.encode(input_text, return_tensors="pt").to(DEVICE)
-                with torch.no_grad():
-                    outputs = model.generate(input_ids, max_length=256)
-                result = tokenizer.decode(outputs[0], skip_special_tokens=True)
-                results.append({'type': 'chat', 'result': result})
-
-        return jsonify({'results': results})
+        probs = F.softmax(outputs.logits, dim=-1)
+        toxic_prob = probs[0][1].item()
+        return jsonify({'text': text, 'is_toxic': toxic_prob > 0.5, 'toxicity_score': round(toxic_prob, 4), 'status': 'toxic' if toxic_prob > 0.5 else 'safe'})
     except Exception as e:
-        logger.error(f"Error in batch_process: {str(e)}")
         return jsonify({'error': str(e)}), 500
 
-@app.errorhandler(404)
-def not_found(error):
-    return jsonify({'error': 'Endpoint not found'}), 404
-
-@app.errorhandler(500)
-def internal_error(error):
-    return jsonify({'error': 'Internal server error'}), 500
-
 if __name__ == '__main__':
-    port = int(os.getenv('PORT', 7860))
-    logger.info(f"Starting server on port {port}...")
-    app.run(host='0.0.0.0', port=port, debug=False)
+    app.run(host='0.0.0.0', port=int(os.getenv('PORT', 7860)))
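
For reference, a minimal smoke test of the endpoints this commit deploys. This is a sketch, not part of the commit: it assumes the server is running locally on the default port 7860 and that the requests package is installed; BASE_URL and the sample payloads are illustrative, while the endpoint paths and JSON keys come from app.py above.

#!/usr/bin/env python3
"""Smoke-test sketch for the NEED AI API (assumes a local server on port 7860)."""
import requests

BASE_URL = "http://localhost:7860"  # assumption: local run; swap in the deployed Space URL

# GET /health reports how many models ModelCache has loaded so far
print(requests.get(f"{BASE_URL}/health").json())

# POST /api/category takes {"query": ...} and returns the predicted category
print(requests.post(f"{BASE_URL}/api/category",
                    json={"query": "fix a leaking kitchen tap"}).json())

# POST /api/search takes {"query": ..., "documents": [...]} and returns the
# documents ranked by cosine similarity of their sentence embeddings
print(requests.post(f"{BASE_URL}/api/search",
                    json={"query": "plumbing help",
                          "documents": ["Electrical repairs",
                                        "Pipe and drain services",
                                        "House painting"]}).json())

# POST /api/moderate takes {"text": ...} and returns a toxicity score plus status
print(requests.post(f"{BASE_URL}/api/moderate",
                    json={"text": "Great service, thank you!"}).json())

Note that the first request against each endpoint triggers the corresponding from_pretrained download in ModelCache, so it will be much slower than later requests.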