"""FastAPI router for word-embedding study operations.

Exposes endpoints for concept/analogy/semantic-field analysis, neighborhood
exploration, visualization data, phrase averaging, clustering, outlier
detection, similarity distributions, interpolation, weighted vector
combination, and similarity-by-raw-vector lookup.

All endpoints delegate to a ``StudyService`` instance supplied to
``init_router``; handlers translate service failures into HTTP 500 responses
while letting deliberate ``HTTPException``s (e.g. 404) pass through.
"""

from typing import List

import numpy as np
from fastapi import APIRouter, HTTPException
from loguru import logger
from pydantic import BaseModel

from services.study_service import StudyService

# NOTE(review): this module-level router is shadowed by the local router
# created in init_router() below; it is kept for backward compatibility in
# case other modules import it directly — confirm before removing.
router = APIRouter(prefix="/api/study")


# --- Existing Request Models ---

class NeighborhoodRequest(BaseModel):
    word: str
    n_neighbors: int = 20


class VisualizationRequest(BaseModel):
    words: List[str]


class ConceptRequest(BaseModel):
    positive_words: List[str]
    negative_words: List[str] = []
    n_results: int = 10


class AnalogyRequest(BaseModel):
    word1: str
    word2: str
    word3: str
    n_results: int = 10


class SemanticFieldRequest(BaseModel):
    words: List[str]
    n_neighbors: int = 5


# --- New Request Models for additional study operations ---

class PhraseRequest(BaseModel):
    words: List[str]


class ClusterRequest(BaseModel):
    words: List[str]
    n_clusters: int = 3


class OutlierRequest(BaseModel):
    words: List[str]


class DistributionRequest(BaseModel):
    word: str
    sample_size: int = 1000


class InterpolationRequest(BaseModel):
    word1: str
    word2: str
    steps: int = 5


class WeightedWord(BaseModel):
    word: str
    weight: float


class CombineRequest(BaseModel):
    positive: List[WeightedWord]
    negative: List[WeightedWord] = []


# --- New Request Model for Similar By Vector Endpoint ---

class SimilarByVectorRequest(BaseModel):
    vector: List[float]
    n: int = 10


def init_router(study_service: StudyService) -> APIRouter:
    """Build and return the /api/study router bound to *study_service*.

    Args:
        study_service: service object implementing the async analysis
            operations the endpoints delegate to.

    Returns:
        A configured ``APIRouter`` with all study endpoints registered.
    """
    router = APIRouter(prefix="/api/study")

    @router.post("/concept")
    async def analyze_concept(request: ConceptRequest):
        """Find words closest to positive words and farthest from negative ones."""
        try:
            return await study_service.analyze_concept(
                request.positive_words, request.negative_words, request.n_results
            )
        except Exception as e:
            logger.error(f"Error analyzing concept: {str(e)}")
            raise HTTPException(status_code=500, detail="Internal server error")

    @router.post("/analogy")
    async def analyze_analogy(request: AnalogyRequest):
        """Solve a word analogy: word1 is to word2 as word3 is to ?"""
        try:
            return await study_service.analyze_analogy(
                request.word1, request.word2, request.word3, request.n_results
            )
        except Exception as e:
            logger.error(f"Error analyzing analogy: {str(e)}")
            raise HTTPException(status_code=500, detail="Internal server error")

    @router.post("/semantic-field")
    async def analyze_semantic_field(request: SemanticFieldRequest):
        """Analyze the semantic field formed by a group of words."""
        try:
            return await study_service.analyze_semantic_field(
                request.words, request.n_neighbors
            )
        except Exception as e:
            logger.error(f"Error analyzing semantic field: {str(e)}")
            raise HTTPException(status_code=500, detail="Internal server error")

    @router.post("/neighborhood")
    async def analyze_neighborhood(request: NeighborhoodRequest):
        """Analyze word neighborhood with detailed semantic information"""
        try:
            return await study_service.analyze_word_neighborhood(
                request.word, request.n_neighbors
            )
        except Exception as e:
            logger.error(f"Error analyzing word neighborhood: {str(e)}")
            raise HTTPException(status_code=500, detail="Internal server error")

    @router.post("/visualization")
    async def get_visualization_data(request: VisualizationRequest):
        """
        Retrieve the words along with their raw vector representations.
        The external visualization service will receive these vectors and
        perform the projection (e.g. to 3D) as needed.
        """
        try:
            return await study_service.get_word_vectors(request.words)
        except Exception as e:
            logger.error(f"Error retrieving visualization data: {str(e)}")
            raise HTTPException(status_code=500, detail="Internal server error")

    @router.post("/phrase")
    async def get_phrase_vector(request: PhraseRequest):
        """
        Compute and return the averaged embedding for a list of words (phrase).
        """
        try:
            vector = await study_service.get_phrase_vector(request.words)
            if vector is None:
                raise HTTPException(status_code=404, detail="No valid vectors found for given words.")
            return {"phrase_vector": vector}
        except HTTPException:
            # BUGFIX: re-raise deliberate HTTP errors (the 404 above) instead
            # of letting the blanket handler below mask them as 500s.
            raise
        except Exception as e:
            logger.error(f"Error computing phrase vector: {str(e)}")
            raise HTTPException(status_code=500, detail="Internal server error")

    @router.post("/cluster")
    async def cluster_words(request: ClusterRequest):
        """
        Cluster the embeddings of the given words using K-Means.
        Returns clusters and centroids.
        """
        try:
            return await study_service.cluster_words(request.words, request.n_clusters)
        except Exception as e:
            logger.error(f"Error clustering words: {str(e)}")
            raise HTTPException(status_code=500, detail="Internal server error")

    @router.post("/outlier")
    async def find_outlier(request: OutlierRequest):
        """
        Identify the outlier (the word least similar to the rest) in a list of words.
        """
        try:
            return await study_service.find_outlier(request.words)
        except Exception as e:
            logger.error(f"Error finding outlier: {str(e)}")
            raise HTTPException(status_code=500, detail="Internal server error")

    @router.post("/distribution")
    async def distance_distribution(request: DistributionRequest):
        """
        Compute the distribution of cosine similarities between the target word
        and a sample of words.
        """
        try:
            return await study_service.distance_distribution(request.word, request.sample_size)
        except Exception as e:
            logger.error(f"Error computing distance distribution: {str(e)}")
            raise HTTPException(status_code=500, detail="Internal server error")

    @router.post("/interpolate")
    async def interpolate_words(request: InterpolationRequest):
        """
        Generate a series of intermediate vectors between two words and
        retrieve the closest word for each step.
        """
        try:
            return await study_service.interpolate_words(request.word1, request.word2, request.steps)
        except Exception as e:
            logger.error(f"Error interpolating words: {str(e)}")
            raise HTTPException(status_code=500, detail="Internal server error")

    @router.post("/combine")
    async def combine_word_vectors(request: CombineRequest):
        """
        Combine word vectors given weighted positive and negative contributions.
        Returns the combined normalized vector.
        """
        try:
            combined_vector = await study_service.combine_word_vectors(
                positive=[(item.word, item.weight) for item in request.positive],
                negative=[(item.word, item.weight) for item in request.negative]
            )
            if combined_vector is None:
                raise HTTPException(status_code=404, detail="Could not compute combined vector.")
            return {"combined_vector": combined_vector}
        except HTTPException:
            # BUGFIX: same as /phrase — keep the intended 404 visible to clients.
            raise
        except Exception as e:
            logger.error(f"Error combining word vectors: {str(e)}")
            raise HTTPException(status_code=500, detail="Internal server error")

    # --- New Endpoint: Similar By Vector ---
    @router.post("/similar-by-vector")
    async def similar_by_vector_endpoint(request: SimilarByVectorRequest):
        """
        Given a vector (list of floats) and a number n, return the n words most
        similar to that vector.
        """
        try:
            # numpy import hoisted to module level (was a function-local import).
            vector = np.array(request.vector)
            similar = await study_service.word_service.get_similar_by_vector(vector, n=request.n)
            return {"similar_words": similar}
        except Exception as e:
            logger.error(f"Error computing similar words by vector: {str(e)}")
            raise HTTPException(status_code=500, detail="Internal server error")

    return router