# NOTE: removed stray extraction artifact ("Spaces:" / "Sleeping" lines) that
# was not valid Python and carried no content.
| from services.study_service import StudyService | |
| from fastapi import APIRouter, HTTPException | |
| from loguru import logger | |
| from typing import List | |
| from pydantic import BaseModel | |
| router = APIRouter(prefix="/api/study") | |
| # --- Existing Request Models --- | |
class NeighborhoodRequest(BaseModel):
    """Request body for word-neighborhood analysis."""

    # Target word whose semantic neighborhood is analyzed.
    word: str
    # Number of nearest neighbors to return.
    n_neighbors: int = 20
class VisualizationRequest(BaseModel):
    """Request body for fetching raw word vectors for visualization."""

    # Words whose embedding vectors should be returned.
    words: List[str]
class ConceptRequest(BaseModel):
    """Request body for concept analysis (positive minus negative words)."""

    # Words that define the concept.
    positive_words: List[str]
    # Words subtracted from the concept; Pydantic deep-copies this default
    # per instance, so the mutable default is safe here.
    negative_words: List[str] = []
    # Number of result words to return.
    n_results: int = 10
class AnalogyRequest(BaseModel):
    """Request body for analogy analysis (word1 : word2 :: word3 : ?)."""

    # First term of the analogy.
    word1: str
    # Second term of the analogy.
    word2: str
    # Third term of the analogy.
    word3: str
    # Number of candidate completions to return.
    n_results: int = 10
class SemanticFieldRequest(BaseModel):
    """Request body for semantic-field analysis over a set of words."""

    # Seed words defining the semantic field.
    words: List[str]
    # Neighbors considered per word.
    n_neighbors: int = 5
| # --- New Request Models for additional study operations --- | |
class PhraseRequest(BaseModel):
    """Request body for computing an averaged phrase embedding."""

    # Words making up the phrase; their vectors are averaged.
    words: List[str]
class ClusterRequest(BaseModel):
    """Request body for K-Means clustering of word embeddings."""

    # Words to cluster.
    words: List[str]
    # Number of K-Means clusters.
    n_clusters: int = 3
class OutlierRequest(BaseModel):
    """Request body for finding the semantic outlier in a word list."""

    # Candidate words; the one least similar to the rest is the outlier.
    words: List[str]
class DistributionRequest(BaseModel):
    """Request body for computing a cosine-similarity distribution."""

    # Target word compared against a random sample of vocabulary words.
    word: str
    # Number of sampled words used to build the distribution.
    sample_size: int = 1000
class InterpolationRequest(BaseModel):
    """Request body for interpolating between two word vectors."""

    # Interpolation start word.
    word1: str
    # Interpolation end word.
    word2: str
    # Number of intermediate steps to generate.
    steps: int = 5
class WeightedWord(BaseModel):
    """A word paired with a weight for vector combination."""

    # The word contributing its vector.
    word: str
    # Multiplicative weight applied to the word's vector.
    weight: float
class CombineRequest(BaseModel):
    """Request body for weighted combination of word vectors."""

    # Positively weighted contributions.
    positive: List[WeightedWord]
    # Negatively weighted contributions; Pydantic deep-copies this default
    # per instance, so the mutable default is safe here.
    negative: List[WeightedWord] = []
| # --- New Request Model for Similar By Vector Endpoint --- | |
class SimilarByVectorRequest(BaseModel):
    """Request body for finding words most similar to a raw vector."""

    # Embedding vector to compare against the vocabulary.
    vector: List[float]
    # Number of similar words to return.
    n: int = 10
def init_router(study_service: StudyService):
    """Build and return the study API router bound to *study_service*.

    Each endpoint is a closure over ``study_service``.  Service-level
    failures are logged and mapped to HTTP 500; deliberate HTTPExceptions
    raised by handlers (e.g. 404s) are re-raised untouched.

    NOTE(review): the original code defined these coroutines but never
    registered them (no ``@router.post`` decorators), so the returned
    router exposed no routes.  Decorators were added below; the paths are
    derived from the handler names — TODO confirm them against API clients.

    Args:
        study_service: Project service implementing the study operations.

    Returns:
        APIRouter with all study endpoints mounted under ``/api/study``.
    """
    router = APIRouter(prefix="/api/study")

    @router.post("/concept")
    async def analyze_concept(request: ConceptRequest):
        """Analyze a concept defined by positive and negative word lists."""
        try:
            return await study_service.analyze_concept(
                request.positive_words,
                request.negative_words,
                request.n_results,
            )
        except Exception as e:
            logger.error(f"Error analyzing concept: {str(e)}")
            raise HTTPException(status_code=500, detail="Internal server error")

    @router.post("/analogy")
    async def analyze_analogy(request: AnalogyRequest):
        """Solve the analogy word1 : word2 :: word3 : ?."""
        try:
            return await study_service.analyze_analogy(
                request.word1,
                request.word2,
                request.word3,
                request.n_results,
            )
        except Exception as e:
            logger.error(f"Error analyzing analogy: {str(e)}")
            raise HTTPException(status_code=500, detail="Internal server error")

    @router.post("/semantic-field")
    async def analyze_semantic_field(request: SemanticFieldRequest):
        """Analyze the semantic field spanned by a set of words."""
        try:
            return await study_service.analyze_semantic_field(
                request.words,
                request.n_neighbors,
            )
        except Exception as e:
            logger.error(f"Error analyzing semantic field: {str(e)}")
            raise HTTPException(status_code=500, detail="Internal server error")

    @router.post("/neighborhood")
    async def analyze_neighborhood(request: NeighborhoodRequest):
        """Analyze word neighborhood with detailed semantic information"""
        try:
            return await study_service.analyze_word_neighborhood(
                request.word,
                request.n_neighbors,
            )
        except Exception as e:
            logger.error(f"Error analyzing word neighborhood: {str(e)}")
            raise HTTPException(status_code=500, detail="Internal server error")

    @router.post("/visualization")
    async def get_visualization_data(request: VisualizationRequest):
        """
        Retrieve the words along with their raw vector representations.
        The external visualization service will receive these vectors and
        perform the projection (e.g. to 3D) as needed.
        """
        try:
            return await study_service.get_word_vectors(request.words)
        except Exception as e:
            logger.error(f"Error retrieving visualization data: {str(e)}")
            raise HTTPException(status_code=500, detail="Internal server error")

    @router.post("/phrase-vector")
    async def get_phrase_vector(request: PhraseRequest):
        """
        Compute and return the averaged embedding for a list of words (phrase).
        """
        try:
            vector = await study_service.get_phrase_vector(request.words)
            if vector is None:
                raise HTTPException(status_code=404, detail="No valid vectors found for given words.")
            return {"phrase_vector": vector}
        except HTTPException:
            # FIX: the original generic handler swallowed the deliberate 404
            # above and re-raised it as a 500; let it propagate unchanged.
            raise
        except Exception as e:
            logger.error(f"Error computing phrase vector: {str(e)}")
            raise HTTPException(status_code=500, detail="Internal server error")

    @router.post("/cluster")
    async def cluster_words(request: ClusterRequest):
        """
        Cluster the embeddings of the given words using K-Means.
        Returns clusters and centroids.
        """
        try:
            return await study_service.cluster_words(request.words, request.n_clusters)
        except Exception as e:
            logger.error(f"Error clustering words: {str(e)}")
            raise HTTPException(status_code=500, detail="Internal server error")

    @router.post("/outlier")
    async def find_outlier(request: OutlierRequest):
        """
        Identify the outlier (the word least similar to the rest) in a list of words.
        """
        try:
            return await study_service.find_outlier(request.words)
        except Exception as e:
            logger.error(f"Error finding outlier: {str(e)}")
            raise HTTPException(status_code=500, detail="Internal server error")

    @router.post("/distance-distribution")
    async def distance_distribution(request: DistributionRequest):
        """
        Compute the distribution of cosine similarities between the target word and a sample of words.
        """
        try:
            return await study_service.distance_distribution(request.word, request.sample_size)
        except Exception as e:
            logger.error(f"Error computing distance distribution: {str(e)}")
            raise HTTPException(status_code=500, detail="Internal server error")

    @router.post("/interpolate")
    async def interpolate_words(request: InterpolationRequest):
        """
        Generate a series of intermediate vectors between two words and retrieve the closest word for each step.
        """
        try:
            return await study_service.interpolate_words(request.word1, request.word2, request.steps)
        except Exception as e:
            logger.error(f"Error interpolating words: {str(e)}")
            raise HTTPException(status_code=500, detail="Internal server error")

    @router.post("/combine")
    async def combine_word_vectors(request: CombineRequest):
        """
        Combine word vectors given weighted positive and negative contributions.
        Returns the combined normalized vector.
        """
        try:
            combined_vector = await study_service.combine_word_vectors(
                positive=[(item.word, item.weight) for item in request.positive],
                negative=[(item.word, item.weight) for item in request.negative],
            )
            if combined_vector is None:
                raise HTTPException(status_code=404, detail="Could not compute combined vector.")
            return {"combined_vector": combined_vector}
        except HTTPException:
            # FIX: same 404-swallowing issue as get_phrase_vector.
            raise
        except Exception as e:
            logger.error(f"Error combining word vectors: {str(e)}")
            raise HTTPException(status_code=500, detail="Internal server error")

    # --- New Endpoint: Similar By Vector ---
    @router.post("/similar-by-vector")
    async def similar_by_vector_endpoint(request: SimilarByVectorRequest):
        """
        Given a vector (list of floats) and a number n, return the n words most similar to that vector.
        """
        try:
            # Local import keeps numpy off the module-import path.
            import numpy as np
            vector = np.array(request.vector)
            similar = await study_service.word_service.get_similar_by_vector(vector, n=request.n)
            return {"similar_words": similar}
        except Exception as e:
            logger.error(f"Error computing similar words by vector: {str(e)}")
            raise HTTPException(status_code=500, detail="Internal server error")

    return router