yangg40 committed on
Commit
c9a1a62
·
verified ·
1 Parent(s): 9edd100

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +56 -16
app.py CHANGED
@@ -1,11 +1,13 @@
1
  """
2
- HuggingFace Space - Embedding API
3
- Lightweight stateless API for generating text embeddings
4
  """
5
 
6
  import os
7
  from flask import Flask, request, jsonify
8
- from sentence_transformers import SentenceTransformer
 
 
9
  import logging
10
 
11
  # Configure logging
@@ -15,13 +17,49 @@ logger = logging.getLogger(__name__)
15
  # Initialize Flask app
16
  app = Flask(__name__)
17
 
18
- # Load the embedding model (SPECTER2, 768 dimensions)
19
- # SPECTER2 is trained on scientific papers with citation relationships
20
- # Best-in-class for academic paper embeddings
21
- logger.info("Loading SPECTER2 embedding model...")
22
- model = SentenceTransformer('allenai/specter2')
 
 
 
23
  logger.info("SPECTER2 model loaded successfully!")
24
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
25
 
26
  @app.route('/')
27
  def health():
@@ -29,9 +67,11 @@ def health():
29
  return jsonify({
30
  "status": "healthy",
31
  "model": "allenai/specter2",
 
32
  "dimensions": 768,
 
33
  "endpoints": {
34
- "/embed": "POST - Generate embeddings for text",
35
  "/batch_embed": "POST - Generate embeddings for multiple texts"
36
  }
37
  })
@@ -44,7 +84,7 @@ def embed_text():
44
 
45
  Request body:
46
  {
47
- "text": "Your text here"
48
  }
49
 
50
  Response:
@@ -69,11 +109,11 @@ def embed_text():
69
  }), 400
70
 
71
  # Generate embedding
72
- embedding = model.encode(text, convert_to_numpy=True)
73
 
74
  return jsonify({
75
- "embedding": embedding.tolist(),
76
- "dimensions": len(embedding)
77
  })
78
 
79
  except Exception as e:
@@ -91,7 +131,7 @@ def batch_embed_texts():
91
 
92
  Request body:
93
  {
94
- "texts": ["Text 1", "Text 2", ...]
95
  }
96
 
97
  Response:
@@ -123,12 +163,12 @@ def batch_embed_texts():
123
  }), 400
124
 
125
  # Generate embeddings
126
- embeddings = model.encode(texts, convert_to_numpy=True)
127
 
128
  return jsonify({
129
  "embeddings": embeddings.tolist(),
130
  "count": len(embeddings),
131
- "dimensions": embeddings.shape[1] if len(embeddings) > 0 else 768
132
  })
133
 
134
  except Exception as e:
 
1
  """
2
+ HuggingFace Space - SPECTER2 Embedding API
3
+ Academic paper embeddings using SPECTER2 with adapters
4
  """
5
 
6
  import os
7
  from flask import Flask, request, jsonify
8
+ from transformers import AutoTokenizer
9
+ from adapters import AutoAdapterModel
10
+ import torch
11
  import logging
12
 
13
  # Configure logging
 
17
  # Initialize Flask app
18
  app = Flask(__name__)
19
 
20
+ # Load SPECTER2 model with adapters
21
+ logger.info("Loading SPECTER2 base model and tokenizer...")
22
+ tokenizer = AutoTokenizer.from_pretrained('allenai/specter2_base')
23
+ model = AutoAdapterModel.from_pretrained('allenai/specter2_base')
24
+
25
+ logger.info("Loading SPECTER2 proximity adapter...")
26
+ # Load the proximity adapter for similarity/retrieval tasks
27
+ model.load_adapter("allenai/specter2", source="hf", load_as="specter2", set_active=True)
28
  logger.info("SPECTER2 model loaded successfully!")
29
 
30
+ # Move to GPU if available
31
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
32
+ model = model.to(device)
33
+ logger.info(f"Using device: {device}")
34
+
35
+
36
+ def get_embeddings(texts):
37
+ """
38
+ Generate SPECTER2 embeddings for a list of texts
39
+
40
+ Args:
41
+ texts: List of strings (paper titles + abstracts)
42
+
43
+ Returns:
44
+ numpy array of embeddings (batch_size, 768)
45
+ """
46
+ # Tokenize
47
+ inputs = tokenizer(
48
+ texts,
49
+ padding=True,
50
+ truncation=True,
51
+ return_tensors="pt",
52
+ max_length=512
53
+ ).to(device)
54
+
55
+ # Generate embeddings
56
+ with torch.no_grad():
57
+ output = model(**inputs)
58
+ # Use [CLS] token embedding (first token)
59
+ embeddings = output.last_hidden_state[:, 0, :]
60
+
61
+ return embeddings.cpu().numpy()
62
+
63
 
64
  @app.route('/')
65
  def health():
 
67
  return jsonify({
68
  "status": "healthy",
69
  "model": "allenai/specter2",
70
+ "adapter": "proximity (similarity/retrieval)",
71
  "dimensions": 768,
72
+ "device": str(device),
73
  "endpoints": {
74
+ "/embed": "POST - Generate embedding for single text",
75
  "/batch_embed": "POST - Generate embeddings for multiple texts"
76
  }
77
  })
 
84
 
85
  Request body:
86
  {
87
+ "text": "Your paper title and abstract here"
88
  }
89
 
90
  Response:
 
109
  }), 400
110
 
111
  # Generate embedding
112
+ embeddings = get_embeddings([text])
113
 
114
  return jsonify({
115
+ "embedding": embeddings[0].tolist(),
116
+ "dimensions": len(embeddings[0])
117
  })
118
 
119
  except Exception as e:
 
131
 
132
  Request body:
133
  {
134
+ "texts": ["Paper 1 title and abstract", "Paper 2 title and abstract", ...]
135
  }
136
 
137
  Response:
 
163
  }), 400
164
 
165
  # Generate embeddings
166
+ embeddings = get_embeddings(texts)
167
 
168
  return jsonify({
169
  "embeddings": embeddings.tolist(),
170
  "count": len(embeddings),
171
+ "dimensions": embeddings.shape[1]
172
  })
173
 
174
  except Exception as e: