Spaces:
Sleeping
Sleeping
added new route
Browse files- routes/study.py +22 -4
routes/study.py
CHANGED
|
@@ -6,7 +6,8 @@ from pydantic import BaseModel
|
|
| 6 |
|
| 7 |
router = APIRouter(prefix="/api/study")
|
| 8 |
|
| 9 |
-
# Existing Request Models
|
|
|
|
| 10 |
class NeighborhoodRequest(BaseModel):
|
| 11 |
word: str
|
| 12 |
n_neighbors: int = 20
|
|
@@ -29,7 +30,7 @@ class SemanticFieldRequest(BaseModel):
|
|
| 29 |
words: List[str]
|
| 30 |
n_neighbors: int = 5
|
| 31 |
|
| 32 |
-
# New Request Models for additional study operations
|
| 33 |
|
| 34 |
class PhraseRequest(BaseModel):
|
| 35 |
words: List[str]
|
|
@@ -58,6 +59,10 @@ class CombineRequest(BaseModel):
|
|
| 58 |
positive: List[WeightedWord]
|
| 59 |
negative: List[WeightedWord] = []
|
| 60 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 61 |
|
| 62 |
def init_router(study_service: StudyService):
|
| 63 |
router = APIRouter(prefix="/api/study")
|
|
@@ -123,8 +128,6 @@ def init_router(study_service: StudyService):
|
|
| 123 |
logger.error(f"Error retrieving visualization data: {str(e)}")
|
| 124 |
raise HTTPException(status_code=500, detail="Internal server error")
|
| 125 |
|
| 126 |
-
# New Endpoints
|
| 127 |
-
|
| 128 |
@router.post("/phrase")
|
| 129 |
async def get_phrase_vector(request: PhraseRequest):
|
| 130 |
"""
|
|
@@ -201,5 +204,20 @@ def init_router(study_service: StudyService):
|
|
| 201 |
except Exception as e:
|
| 202 |
logger.error(f"Error combining word vectors: {str(e)}")
|
| 203 |
raise HTTPException(status_code=500, detail="Internal server error")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 204 |
|
| 205 |
return router
|
|
|
|
| 6 |
|
| 7 |
router = APIRouter(prefix="/api/study")
|
| 8 |
|
| 9 |
+
# --- Existing Request Models ---
|
| 10 |
+
|
| 11 |
class NeighborhoodRequest(BaseModel):
|
| 12 |
word: str
|
| 13 |
n_neighbors: int = 20
|
|
|
|
| 30 |
words: List[str]
|
| 31 |
n_neighbors: int = 5
|
| 32 |
|
| 33 |
+
# --- New Request Models for additional study operations ---
|
| 34 |
|
| 35 |
class PhraseRequest(BaseModel):
|
| 36 |
words: List[str]
|
|
|
|
| 59 |
positive: List[WeightedWord]
|
| 60 |
negative: List[WeightedWord] = []
|
| 61 |
|
| 62 |
+
# --- New Request Model for Similar By Vector Endpoint ---
|
| 63 |
+
class SimilarByVectorRequest(BaseModel):
|
| 64 |
+
vector: List[float]
|
| 65 |
+
n: int = 10
|
| 66 |
|
| 67 |
def init_router(study_service: StudyService):
|
| 68 |
router = APIRouter(prefix="/api/study")
|
|
|
|
| 128 |
logger.error(f"Error retrieving visualization data: {str(e)}")
|
| 129 |
raise HTTPException(status_code=500, detail="Internal server error")
|
| 130 |
|
|
|
|
|
|
|
| 131 |
@router.post("/phrase")
|
| 132 |
async def get_phrase_vector(request: PhraseRequest):
|
| 133 |
"""
|
|
|
|
| 204 |
except Exception as e:
|
| 205 |
logger.error(f"Error combining word vectors: {str(e)}")
|
| 206 |
raise HTTPException(status_code=500, detail="Internal server error")
|
| 207 |
+
|
| 208 |
+
# --- New Endpoint: Similar By Vector ---
|
| 209 |
+
@router.post("/similar-by-vector")
|
| 210 |
+
async def similar_by_vector_endpoint(request: SimilarByVectorRequest):
|
| 211 |
+
"""
|
| 212 |
+
Given a vector (list of floats) and a number n, return the n words most similar to that vector.
|
| 213 |
+
"""
|
| 214 |
+
try:
|
| 215 |
+
import numpy as np
|
| 216 |
+
vector = np.array(request.vector)
|
| 217 |
+
similar = await study_service.word_service.get_similar_by_vector(vector, n=request.n)
|
| 218 |
+
return {"similar_words": similar}
|
| 219 |
+
except Exception as e:
|
| 220 |
+
logger.error(f"Error computing similar words by vector: {str(e)}")
|
| 221 |
+
raise HTTPException(status_code=500, detail="Internal server error")
|
| 222 |
|
| 223 |
return router
|