Spaces:
Sleeping
Sleeping
File size: 8,485 Bytes
73390f3 1f3b631 c025f35 1f3b631 b36798e c025f35 73390f3 c025f35 73390f3 c025f35 73390f3 c025f35 73390f3 c025f35 b36798e c64867f b36798e c64867f 1f3b631 73390f3 1f3b631 73390f3 1f3b631 73390f3 1f3b631 c025f35 1f3b631 73390f3 c025f35 73390f3 c025f35 1f3b631 c025f35 1f3b631 73390f3 c64867f 73390f3 c64867f 1f3b631 c64867f 1f3b631 c64867f 1f3b631 b36798e 1f3b631 c64867f | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 | from services.study_service import StudyService
from fastapi import APIRouter, HTTPException
from loguru import logger
from typing import List
from pydantic import BaseModel
# Module-level router for the study endpoints.
# NOTE(review): init_router() creates and returns its own APIRouter with the
# same prefix, and within this module no endpoints are registered on this
# instance — confirm whether anything still imports it before relying on it.
_STUDY_PREFIX = "/api/study"
router = APIRouter(prefix=_STUDY_PREFIX)
# --- Existing Request Models ---
class NeighborhoodRequest(BaseModel):
    """Request body for ``POST /api/study/neighborhood``."""

    # Word whose neighborhood is analyzed.
    word: str
    # Forwarded as ``n_neighbors`` to ``StudyService.analyze_word_neighborhood``.
    n_neighbors: int = 20
class VisualizationRequest(BaseModel):
    """Request body for ``POST /api/study/visualization``."""

    # Words whose raw vectors are fetched via ``StudyService.get_word_vectors``.
    words: List[str]
class ConceptRequest(BaseModel):
    """Request body for ``POST /api/study/concept``.

    All three fields are forwarded positionally to
    ``StudyService.analyze_concept``.
    """

    # Words contributing positively to the concept.
    positive_words: List[str]
    # Optional words contributing negatively; defaults to none.
    negative_words: List[str] = []
    # Forwarded as the result-count argument.
    n_results: int = 10
class AnalogyRequest(BaseModel):
    """Request body for ``POST /api/study/analogy``.

    Fields are forwarded in order to ``StudyService.analyze_analogy``.
    NOTE(review): presumably read as ``word1 : word2 :: word3 : ?`` —
    confirm against the service implementation.
    """

    word1: str
    word2: str
    word3: str
    # Forwarded as the result-count argument.
    n_results: int = 10
class SemanticFieldRequest(BaseModel):
    """Request body for ``POST /api/study/semantic-field``."""

    # Words forming the semantic field.
    words: List[str]
    # Forwarded to ``StudyService.analyze_semantic_field``.
    n_neighbors: int = 5
# --- New Request Models for additional study operations ---
class PhraseRequest(BaseModel):
    """Request body for ``POST /api/study/phrase``."""

    # Words of the phrase whose embeddings are averaged by the service.
    words: List[str]
class ClusterRequest(BaseModel):
    """Request body for ``POST /api/study/cluster`` (K-Means clustering)."""

    # Words whose embeddings are clustered.
    words: List[str]
    # Number of clusters requested from ``StudyService.cluster_words``.
    n_clusters: int = 3
class OutlierRequest(BaseModel):
    """Request body for ``POST /api/study/outlier``."""

    # Candidate words; the service picks the one least similar to the rest.
    words: List[str]
class DistributionRequest(BaseModel):
    """Request body for ``POST /api/study/distribution``."""

    # Target word whose similarity distribution is computed.
    word: str
    # Forwarded to ``StudyService.distance_distribution``; presumably the
    # number of words sampled for comparison — confirm in the service.
    sample_size: int = 1000
class InterpolationRequest(BaseModel):
    """Request body for ``POST /api/study/interpolate``."""

    # Start word of the interpolation.
    word1: str
    # End word of the interpolation.
    word2: str
    # Forwarded to ``StudyService.interpolate_words``; whether the endpoints
    # are counted in ``steps`` is decided by the service — TODO confirm.
    steps: int = 5
class WeightedWord(BaseModel):
    """A word paired with the weight of its contribution in ``CombineRequest``."""

    word: str
    # Sent to the service as the second element of a ``(word, weight)`` tuple.
    weight: float
class CombineRequest(BaseModel):
    """Request body for ``POST /api/study/combine``."""

    # Words whose weighted vectors are added.
    positive: List[WeightedWord]
    # Optional words whose weighted vectors are subtracted; defaults to none.
    negative: List[WeightedWord] = []
# --- New Request Model for Similar By Vector Endpoint ---
class SimilarByVectorRequest(BaseModel):
    """Request body for ``POST /api/study/similar-by-vector``."""

    # Raw query vector; converted to a numpy array by the endpoint.
    vector: List[float]
    # How many of the most similar words to return.
    n: int = 10
def init_router(study_service: StudyService) -> APIRouter:
    """Build the ``/api/study`` router whose endpoints delegate to *study_service*.

    A fresh ``APIRouter`` is created on every call (the module-level
    ``router`` is not used), all study endpoints are registered on it, and it
    is returned for the caller to mount.

    Error handling is shared by every endpoint through ``_guarded``:

    * an ``HTTPException`` raised while handling a request propagates
      unchanged, so deliberate client errors (e.g. 404) keep their status;
    * any other exception is logged with its traceback and converted into a
      generic HTTP 500 so internal details are not leaked to clients.

    Args:
        study_service: Backend that performs the embedding analyses.

    Returns:
        The configured ``APIRouter``.
    """
    from typing import Any, Awaitable, Callable

    router = APIRouter(prefix="/api/study")

    async def _guarded(action: str, call: Callable[[], Awaitable[Any]]) -> Any:
        """Await ``call()`` and map unexpected failures to HTTP 500.

        ``call`` is a zero-argument factory instead of a ready coroutine so
        that errors raised while *creating* the coroutine (e.g. a missing
        service attribute) are handled here as well.

        Args:
            action: Phrase completing the log message ``"Error <action>: ..."``.
            call: Factory returning the awaitable that does the work.

        Raises:
            HTTPException: re-raised as-is, or a new 500 for any other error.
        """
        try:
            return await call()
        except HTTPException:
            # Intentional HTTP errors already carry the correct status code.
            raise
        except Exception as e:
            # logger.exception keeps the traceback; loguru formats "{}" lazily.
            logger.exception("Error {}: {}", action, e)
            raise HTTPException(status_code=500, detail="Internal server error") from e

    # Delegates to StudyService.analyze_concept.
    @router.post("/concept")
    async def analyze_concept(request: ConceptRequest):
        return await _guarded(
            "analyzing concept",
            lambda: study_service.analyze_concept(
                request.positive_words,
                request.negative_words,
                request.n_results,
            ),
        )

    # Delegates to StudyService.analyze_analogy.
    @router.post("/analogy")
    async def analyze_analogy(request: AnalogyRequest):
        return await _guarded(
            "analyzing analogy",
            lambda: study_service.analyze_analogy(
                request.word1,
                request.word2,
                request.word3,
                request.n_results,
            ),
        )

    # Delegates to StudyService.analyze_semantic_field.
    @router.post("/semantic-field")
    async def analyze_semantic_field(request: SemanticFieldRequest):
        return await _guarded(
            "analyzing semantic field",
            lambda: study_service.analyze_semantic_field(
                request.words,
                request.n_neighbors,
            ),
        )

    @router.post("/neighborhood")
    async def analyze_neighborhood(request: NeighborhoodRequest):
        """Analyze word neighborhood with detailed semantic information"""
        return await _guarded(
            "analyzing word neighborhood",
            lambda: study_service.analyze_word_neighborhood(
                request.word,
                request.n_neighbors,
            ),
        )

    @router.post("/visualization")
    async def get_visualization_data(request: VisualizationRequest):
        """
        Retrieve the words along with their raw vector representations.
        The external visualization service will receive these vectors and
        perform the projection (e.g. to 3D) as needed.
        """
        return await _guarded(
            "retrieving visualization data",
            lambda: study_service.get_word_vectors(request.words),
        )

    @router.post("/phrase")
    async def get_phrase_vector(request: PhraseRequest):
        """
        Compute and return the averaged embedding for a list of words (phrase).
        """
        vector = await _guarded(
            "computing phrase vector",
            lambda: study_service.get_phrase_vector(request.words),
        )
        # Raised outside _guarded so the 404 is not rewritten into a 500.
        if vector is None:
            raise HTTPException(status_code=404, detail="No valid vectors found for given words.")
        return {"phrase_vector": vector}

    @router.post("/cluster")
    async def cluster_words(request: ClusterRequest):
        """
        Cluster the embeddings of the given words using K-Means.
        Returns clusters and centroids.
        """
        return await _guarded(
            "clustering words",
            lambda: study_service.cluster_words(request.words, request.n_clusters),
        )

    @router.post("/outlier")
    async def find_outlier(request: OutlierRequest):
        """
        Identify the outlier (the word least similar to the rest) in a list of words.
        """
        return await _guarded(
            "finding outlier",
            lambda: study_service.find_outlier(request.words),
        )

    @router.post("/distribution")
    async def distance_distribution(request: DistributionRequest):
        """
        Compute the distribution of cosine similarities between the target word and a sample of words.
        """
        return await _guarded(
            "computing distance distribution",
            lambda: study_service.distance_distribution(request.word, request.sample_size),
        )

    @router.post("/interpolate")
    async def interpolate_words(request: InterpolationRequest):
        """
        Generate a series of intermediate vectors between two words and retrieve the closest word for each step.
        """
        return await _guarded(
            "interpolating words",
            lambda: study_service.interpolate_words(request.word1, request.word2, request.steps),
        )

    @router.post("/combine")
    async def combine_word_vectors(request: CombineRequest):
        """
        Combine word vectors given weighted positive and negative contributions.
        Returns the combined normalized vector.
        """
        combined_vector = await _guarded(
            "combining word vectors",
            lambda: study_service.combine_word_vectors(
                positive=[(item.word, item.weight) for item in request.positive],
                negative=[(item.word, item.weight) for item in request.negative],
            ),
        )
        # Raised outside _guarded so the 404 is not rewritten into a 500.
        if combined_vector is None:
            raise HTTPException(status_code=404, detail="Could not compute combined vector.")
        return {"combined_vector": combined_vector}

    # --- New Endpoint: Similar By Vector ---
    @router.post("/similar-by-vector")
    async def similar_by_vector_endpoint(request: SimilarByVectorRequest):
        """
        Given a vector (list of floats) and a number n, return the n words most similar to that vector.
        """

        async def _lookup():
            # Lazy import: numpy is only required once this endpoint is called.
            import numpy as np

            vector = np.array(request.vector)
            return await study_service.word_service.get_similar_by_vector(vector, n=request.n)

        similar = await _guarded("computing similar words by vector", _lookup)
        return {"similar_words": similar}

    return router
|