yangg40 commited on
Commit
5bae6d2
·
verified ·
1 Parent(s): 4f66e66

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +11 -10
app.py CHANGED
@@ -15,11 +15,12 @@ logger = logging.getLogger(__name__)
15
  # Initialize Flask app
16
  app = Flask(__name__)
17
 
18
- # Load the embedding model (all-MiniLM-L6-v2, 384 dimensions)
19
- # This model is small (~80MB) and fast on CPU
20
- logger.info("Loading embedding model...")
21
- model = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2')
22
- logger.info("Model loaded successfully!")
 
23
 
24
 
25
  @app.route('/')
@@ -27,8 +28,8 @@ def health():
27
  """Health check endpoint"""
28
  return jsonify({
29
  "status": "healthy",
30
- "model": "all-MiniLM-L6-v2",
31
- "dimensions": 384,
32
  "endpoints": {
33
  "/embed": "POST - Generate embeddings for text",
34
  "/batch_embed": "POST - Generate embeddings for multiple texts"
@@ -49,7 +50,7 @@ def embed_text():
49
  Response:
50
  {
51
  "embedding": [0.123, -0.456, ...],
52
- "dimensions": 384
53
  }
54
  """
55
  try:
@@ -97,7 +98,7 @@ def batch_embed_texts():
97
  {
98
  "embeddings": [[0.123, ...], [0.456, ...], ...],
99
  "count": 2,
100
- "dimensions": 384
101
  }
102
  """
103
  try:
@@ -127,7 +128,7 @@ def batch_embed_texts():
127
  return jsonify({
128
  "embeddings": embeddings.tolist(),
129
  "count": len(embeddings),
130
- "dimensions": embeddings.shape[1] if len(embeddings) > 0 else 384
131
  })
132
 
133
  except Exception as e:
 
15
  # Initialize Flask app
16
  app = Flask(__name__)
17
 
18
+ # Load the embedding model (SPECTER2, 768 dimensions)
19
+ # SPECTER2 is trained on scientific papers with citation relationships
20
+ # Best-in-class for academic paper embeddings
21
+ logger.info("Loading SPECTER2 embedding model...")
22
+ model = SentenceTransformer('allenai/specter2')
23
+ logger.info("SPECTER2 model loaded successfully!")
24
 
25
 
26
  @app.route('/')
 
28
  """Health check endpoint"""
29
  return jsonify({
30
  "status": "healthy",
31
+ "model": "allenai/specter2",
32
+ "dimensions": 768,
33
  "endpoints": {
34
  "/embed": "POST - Generate embeddings for text",
35
  "/batch_embed": "POST - Generate embeddings for multiple texts"
 
50
  Response:
51
  {
52
  "embedding": [0.123, -0.456, ...],
53
+ "dimensions": 768
54
  }
55
  """
56
  try:
 
98
  {
99
  "embeddings": [[0.123, ...], [0.456, ...], ...],
100
  "count": 2,
101
+ "dimensions": 768
102
  }
103
  """
104
  try:
 
128
  return jsonify({
129
  "embeddings": embeddings.tolist(),
130
  "count": len(embeddings),
131
+ "dimensions": embeddings.shape[1] if len(embeddings) > 0 else 768
132
  })
133
 
134
  except Exception as e: