SyedSyab commited on
Commit
7b031dc
·
1 Parent(s): b3f4657

Update ML service to use transformers instead of sentence-transformers for compatibility

Browse files
Files changed (2) hide show
  1. app.py +25 -8
  2. requirements.txt +1 -1
app.py CHANGED
@@ -1,11 +1,25 @@
1
  from flask import Flask, request, jsonify
2
- from sentence_transformers import SentenceTransformer
 
3
  import os
4
 
5
  app = Flask(__name__)
6
 
7
  # Load your model once
8
- model = SentenceTransformer("all-MiniLM-L6-v2")
 
 
 
 
 
 
 
 
 
 
 
 
 
9
 
10
  @app.route('/api/predict', methods=['POST'])
11
  def predict():
@@ -18,13 +32,16 @@ def predict():
18
  if not isinstance(texts, list):
19
  return jsonify({'error': 'Data must be a list of texts'}), 400
20
 
21
- # Generate embeddings
22
- embeddings = model.encode(texts, normalize_embeddings=True)
 
 
 
 
 
 
23
 
24
- # Convert to list format
25
- embeddings_list = embeddings.tolist() if hasattr(embeddings, 'tolist') else embeddings
26
-
27
- return jsonify({'data': embeddings_list})
28
 
29
  except Exception as e:
30
  return jsonify({'error': str(e)}), 500
 
1
  from flask import Flask, request, jsonify
2
+ from transformers import AutoTokenizer, AutoModel
3
+ import torch
4
  import os
5
 
6
  app = Flask(__name__)
7
 
8
  # Load your model once
9
+ model_name = "sentence-transformers/all-MiniLM-L6-v2"
10
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
11
+ model = AutoModel.from_pretrained(model_name)
12
+
13
+ def get_embedding(text):
14
+ """Generate embedding for a single text"""
15
+ inputs = tokenizer(text, return_tensors='pt', truncation=True, padding=True, max_length=512)
16
+ with torch.no_grad():
17
+ outputs = model(**inputs)
18
+ # Use mean pooling over token embeddings
19
+ embeddings = outputs.last_hidden_state.mean(dim=1)
20
+ # Normalize the embeddings
21
+ embeddings = torch.nn.functional.normalize(embeddings, p=2, dim=1)
22
+ return embeddings.squeeze().tolist()
23
 
24
  @app.route('/api/predict', methods=['POST'])
25
  def predict():
 
32
  if not isinstance(texts, list):
33
  return jsonify({'error': 'Data must be a list of texts'}), 400
34
 
35
+ # Generate embeddings for each text
36
+ embeddings = []
37
+ for text in texts:
38
+ if isinstance(text, str):
39
+ embedding = get_embedding(text)
40
+ embeddings.append(embedding)
41
+ else:
42
+ return jsonify({'error': 'All items in data must be strings'}), 400
43
 
44
+ return jsonify({'data': embeddings})
 
 
 
45
 
46
  except Exception as e:
47
  return jsonify({'error': str(e)}), 500
requirements.txt CHANGED
@@ -1,4 +1,4 @@
1
  Flask==2.3.3
2
- sentence-transformers==2.2.2
3
  torch>=2.0.0
4
  numpy>=1.21.0
 
1
  Flask==2.3.3
2
+ transformers==4.36.0
3
  torch>=2.0.0
4
  numpy>=1.21.0