Miroir commited on
Commit
b36798e
·
1 Parent(s): 60a624f

added new route

Browse files
Files changed (1) hide show
  1. routes/study.py +22 -4
routes/study.py CHANGED
@@ -6,7 +6,8 @@ from pydantic import BaseModel
6
 
7
  router = APIRouter(prefix="/api/study")
8
 
9
- # Existing Request Models
 
10
  class NeighborhoodRequest(BaseModel):
11
  word: str
12
  n_neighbors: int = 20
@@ -29,7 +30,7 @@ class SemanticFieldRequest(BaseModel):
29
  words: List[str]
30
  n_neighbors: int = 5
31
 
32
- # New Request Models for additional study operations
33
 
34
  class PhraseRequest(BaseModel):
35
  words: List[str]
@@ -58,6 +59,10 @@ class CombineRequest(BaseModel):
58
  positive: List[WeightedWord]
59
  negative: List[WeightedWord] = []
60
 
 
 
 
 
61
 
62
  def init_router(study_service: StudyService):
63
  router = APIRouter(prefix="/api/study")
@@ -123,8 +128,6 @@ def init_router(study_service: StudyService):
123
  logger.error(f"Error retrieving visualization data: {str(e)}")
124
  raise HTTPException(status_code=500, detail="Internal server error")
125
 
126
- # New Endpoints
127
-
128
  @router.post("/phrase")
129
  async def get_phrase_vector(request: PhraseRequest):
130
  """
@@ -201,5 +204,20 @@ def init_router(study_service: StudyService):
201
  except Exception as e:
202
  logger.error(f"Error combining word vectors: {str(e)}")
203
  raise HTTPException(status_code=500, detail="Internal server error")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
204
 
205
  return router
 
6
 
7
  router = APIRouter(prefix="/api/study")
8
 
9
+ # --- Existing Request Models ---
10
+
11
  class NeighborhoodRequest(BaseModel):
12
  word: str
13
  n_neighbors: int = 20
 
30
  words: List[str]
31
  n_neighbors: int = 5
32
 
33
+ # --- New Request Models for additional study operations ---
34
 
35
  class PhraseRequest(BaseModel):
36
  words: List[str]
 
59
  positive: List[WeightedWord]
60
  negative: List[WeightedWord] = []
61
 
62
+ # --- New Request Model for Similar By Vector Endpoint ---
63
+ class SimilarByVectorRequest(BaseModel):
64
+ vector: List[float]
65
+ n: int = 10
66
 
67
  def init_router(study_service: StudyService):
68
  router = APIRouter(prefix="/api/study")
 
128
  logger.error(f"Error retrieving visualization data: {str(e)}")
129
  raise HTTPException(status_code=500, detail="Internal server error")
130
 
 
 
131
  @router.post("/phrase")
132
  async def get_phrase_vector(request: PhraseRequest):
133
  """
 
204
  except Exception as e:
205
  logger.error(f"Error combining word vectors: {str(e)}")
206
  raise HTTPException(status_code=500, detail="Internal server error")
207
+
208
+ # --- New Endpoint: Similar By Vector ---
209
+ @router.post("/similar-by-vector")
210
+ async def similar_by_vector_endpoint(request: SimilarByVectorRequest):
211
+ """
212
+ Given a vector (list of floats) and a number n, return the n words most similar to that vector.
213
+ """
214
+ try:
215
+ import numpy as np
216
+ vector = np.array(request.vector)
217
+ similar = await study_service.word_service.get_similar_by_vector(vector, n=request.n)
218
+ return {"similar_words": similar}
219
+ except Exception as e:
220
+ logger.error(f"Error computing similar words by vector: {str(e)}")
221
+ raise HTTPException(status_code=500, detail="Internal server error")
222
 
223
  return router