Spaces:
Running
Running
| """ | |
| Embedding Inference API | |
| Supports JobBERT v2/v3, Jina AI, and Voyage AI embeddings | |
| """ | |
| from fastapi import FastAPI, HTTPException | |
| from fastapi.middleware.cors import CORSMiddleware | |
| from pydantic import BaseModel, Field | |
| from typing import List, Optional | |
| from sentence_transformers import SentenceTransformer | |
| import os | |
| import logging | |
# Configure root logging at INFO so model-loading progress messages are emitted.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

app = FastAPI(
    title="Embedding Inference API",
    description="Generate embeddings using JobBERT v2/v3, Jina AI, or Voyage AI",
    version="1.0.0",
)

# NOTE(review): wildcard origins combined with allow_credentials=True is a very
# permissive CORS policy; confirm this is intended before exposing the API
# beyond a demo deployment.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Registry of locally loaded models, keyed by the name clients pass as `model`.
MODELS = {}

# The Voyage client is optional: it is only created when an API key is present
# and the `voyageai` package imports and initialises without error.
VOYAGE_API_KEY = os.getenv('VOYAGE_API_KEY', '')
voyage_client = None

if VOYAGE_API_KEY:
    try:
        import voyageai

        voyage_client = voyageai.Client(api_key=VOYAGE_API_KEY)
        logger.info("✓ Voyage AI client initialized")
    except ImportError:
        logger.warning("⚠️ voyageai package not installed")
    except Exception as exc:
        # Best effort: a broken Voyage setup must not stop the local models
        # from being served, so the failure is only logged.
        logger.warning(f"⚠️ Voyage AI initialization failed: {exc}")
def load_models():
    """Load the local embedding models into ``MODELS``.

    Models are loaded in a fixed order (JobBERT-v2, JobBERT-v3, Jina
    embeddings-v3) and stored under the keys ``jobbertv2``, ``jobbertv3`` and
    ``jina``.

    Raises:
        Exception: any error raised while loading is logged and re-raised.
            Models stored before the failure remain in ``MODELS``.
    """
    # (registry key, start message, done message, model id, extra kwargs)
    model_specs = (
        (
            'jobbertv2',
            "Loading JobBERT-v2...",
            "✓ JobBERT-v2 loaded",
            'TechWolf/JobBERT-v2',
            {},
        ),
        (
            'jobbertv3',
            "Loading JobBERT-v3...",
            "✓ JobBERT-v3 loaded",
            'TechWolf/JobBERT-v3',
            {},
        ),
        (
            'jina',
            "Loading Jina AI embeddings-v3...",
            "✓ Jina AI v3 loaded",
            'jinaai/jina-embeddings-v3',
            # The Jina model is loaded with trust_remote_code enabled.
            {"trust_remote_code": True},
        ),
    )

    try:
        for key, start_msg, done_msg, model_id, extra_kwargs in model_specs:
            logger.info(start_msg)
            MODELS[key] = SentenceTransformer(model_id, **extra_kwargs)
            logger.info(done_msg)
        logger.info("All models loaded successfully!")
    except Exception as exc:
        logger.error(f"Error loading models: {exc}")
        raise
# Register the hook with the application; without the decorator the function
# is never invoked and MODELS stays empty, so every local model is rejected.
@app.on_event("startup")
async def startup_event():
    """Load all local embedding models when the application starts."""
    load_models()
class EmbeddingRequest(BaseModel):
    """Request body for embedding generation."""

    # At least one text is required.
    texts: List[str] = Field(
        ...,
        description="List of texts to embed",
        min_items=1,
    )
    model: str = Field(
        ...,
        description="Model to use: 'jobbertv2', 'jobbertv3', 'jina', or 'voyage'",
    )
    # Optional, model-specific knobs; both default to None.
    task: Optional[str] = Field(
        default=None,
        description="Task type for Jina AI: 'retrieval.query', 'retrieval.passage', 'text-matching', etc.",
    )
    input_type: Optional[str] = Field(
        default=None,
        description="Input type for Voyage AI: 'document' or 'query'",
    )

    class Config:
        # Extra schema metadata: an example request payload.
        schema_extra = {
            "example": {
                "texts": ["Software Engineer", "Data Scientist"],
                "model": "jobbertv3",
                "task": "text-matching",
            }
        }
class EmbeddingResponse(BaseModel):
    """Response body carrying the generated embeddings and their metadata."""

    embeddings: List[List[float]] = Field(
        ...,
        description="List of embedding vectors",
    )
    model: str = Field(
        ...,
        description="Model used",
    )
    dimension: int = Field(
        ...,
        description="Embedding dimension",
    )
    num_texts: int = Field(
        ...,
        description="Number of texts processed",
    )
class HealthResponse(BaseModel):
    """Shape of the health-check payload."""

    # Overall service status string.
    status: str
    # Names of the local models currently present in the registry.
    models_loaded: List[str]
    # True when a Voyage AI client has been initialised.
    voyage_available: bool
# Register the handler; without a route decorator it is never served.
@app.get("/")
async def root():
    """Root endpoint with API information"""
    return {
        "message": "Embedding Inference API",
        "version": "1.0.0",
        "endpoints": {
            "/health": "Health check and available models",
            "/embed": "Generate embeddings (POST)",
            "/docs": "API documentation",
        },
    }
# Register the handler at the path advertised by the root endpoint and
# validate the payload against HealthResponse.
@app.get("/health", response_model=HealthResponse)
async def health():
    """Health check endpoint

    Returns the service status, the names of the local models currently in
    ``MODELS``, and whether a Voyage AI client is available.
    """
    return {
        "status": "healthy",
        # Iterating a dict yields its keys, so no .keys() call is needed.
        "models_loaded": list(MODELS),
        "voyage_available": voyage_client is not None,
    }
# Register the handler at the path and method advertised by the root endpoint.
@app.post("/embed", response_model=EmbeddingResponse)
async def create_embeddings(request: EmbeddingRequest):
    """
    Generate embeddings for input texts

    **Models:**
    - `jobbertv2`: JobBERT-v2 (768-dim, job-specific)
    - `jobbertv3`: JobBERT-v3 (768-dim, job-specific, improved performance)
    - `jina`: Jina AI embeddings-v3 (1024-dim, general purpose, supports task types)
    - `voyage`: Voyage AI (1024-dim, requires API key)

    **Jina AI Tasks:**
    - `retrieval.query`: For search queries
    - `retrieval.passage`: For documents/passages
    - `text-matching`: For similarity matching
    - `classification`: For classification tasks
    - `separation`: For clustering

    When `task` is omitted, no task is forwarded to the model.

    **Voyage AI Input Types:**
    - `document`: For documents/passages (used when `input_type` is omitted)
    - `query`: For search queries

    **Errors:**
    - `400`: unknown model name
    - `503`: `voyage` requested but no Voyage AI client is configured
    - `500`: the underlying model or API call failed
    """
    # Model names are matched case-insensitively.
    model_name = request.model.lower()

    if model_name == "voyage":
        if voyage_client is None:
            raise HTTPException(
                status_code=503,
                detail="Voyage AI not available. Set VOYAGE_API_KEY environment variable."
            )
        try:
            result = voyage_client.embed(
                texts=request.texts,
                model="voyage-3",
                input_type=request.input_type or "document",
            )
            embeddings = result.embeddings
            return EmbeddingResponse(
                embeddings=embeddings,
                model="voyage-3",
                dimension=len(embeddings[0]) if embeddings else 0,
                num_texts=len(request.texts),
            )
        except Exception as e:
            # Keep the traceback in the server log; the client only sees the message.
            logger.exception("Voyage AI embedding request failed")
            raise HTTPException(status_code=500, detail=f"Voyage AI error: {str(e)}") from e

    if model_name not in MODELS:
        raise HTTPException(
            status_code=400,
            detail=f"Invalid model '{model_name}'. Choose from: jobbertv2, jobbertv3, jina, voyage"
        )

    try:
        model = MODELS[model_name]
        encode_kwargs = {"convert_to_numpy": True}
        # Only the Jina model receives a task, and only when one was supplied.
        if model_name == "jina" and request.task:
            encode_kwargs["task"] = request.task
        embeddings_list = model.encode(request.texts, **encode_kwargs).tolist()
        return EmbeddingResponse(
            embeddings=embeddings_list,
            model=model_name,
            dimension=len(embeddings_list[0]) if embeddings_list else 0,
            num_texts=len(request.texts),
        )
    except Exception as e:
        logger.exception("Embedding with model '%s' failed", model_name)
        raise HTTPException(status_code=500, detail=f"Model error: {str(e)}") from e
# Register the handler; without a route decorator it is never served.
# NOTE(review): the "/models" path is chosen by convention; the root endpoint
# does not advertise it, so confirm the intended path.
@app.get("/models")
async def list_models():
    """List available models and their specifications

    Returns a mapping from model key to static metadata plus an ``available``
    flag: local models are available when present in ``MODELS``, and ``voyage``
    is available when a Voyage AI client has been initialised.
    """
    return {
        "jobbertv2": {
            "name": "TechWolf/JobBERT-v2",
            "dimension": 768,
            "description": "Job-specific BERT model fine-tuned on job titles",
            "max_tokens": 512,
            "available": "jobbertv2" in MODELS,
        },
        "jobbertv3": {
            "name": "TechWolf/JobBERT-v3",
            "dimension": 768,
            "description": "Latest JobBERT model with improved performance",
            "max_tokens": 512,
            "available": "jobbertv3" in MODELS,
        },
        "jina": {
            "name": "jinaai/jina-embeddings-v3",
            "dimension": 1024,
            "description": "General-purpose embeddings with long context support",
            "max_tokens": 8192,
            "available": "jina" in MODELS,
            "tasks": ["retrieval.query", "retrieval.passage", "text-matching", "classification", "separation"],
        },
        "voyage": {
            "name": "voyage-3",
            "dimension": 1024,
            "description": "State-of-the-art embeddings (requires API key)",
            "max_tokens": 32000,
            "available": voyage_client is not None,
            "input_types": ["document", "query"],
        },
    }
if __name__ == "__main__":
    # uvicorn is only needed when the module is executed directly.
    import uvicorn

    # The PORT environment variable overrides the default of 7860.
    port = int(os.getenv("PORT", 7860))
    # Listen on all interfaces.
    uvicorn.run(app, host="0.0.0.0", port=port)