"""Minimal FastAPI service that encodes text into SPLADE sparse vectors.

The SPLADE masked-LM model is loaded once when this module is imported.
``POST /splade`` returns the vocabulary ids with non-zero weight for the input
text together with those weights; ``GET /`` is a simple status check.
"""
import torch
from transformers import AutoTokenizer, AutoModelForMaskedLM
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
import uvicorn

# Initialize FastAPI
app = FastAPI()

# --- MODEL LOADING (Runs once at startup) ---
MODEL_NAME = "naver/splade-cocondenser-ensembledistil"
# Maximum number of tokens passed to the model; longer inputs are truncated.
MAX_LENGTH = 512
device = "cuda" if torch.cuda.is_available() else "cpu"

print(f"Loading model on: {device}...")
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForMaskedLM.from_pretrained(MODEL_NAME).to(device)
model.eval()
print("Model loaded successfully.")


# Input Schema
class TextRequest(BaseModel):
    """Request body for ``POST /splade``: the raw text to encode."""

    text: str


@app.get("/")
def home():
    """Report that the service is running and which device the model is on."""
    return {"status": "SPLADE API is running", "device": device}


def _splade_pool(logits: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
    """Apply SPLADE activation log(1 + ReLU(logits)) and max-pool over tokens.

    Args:
        logits: MLM logits shaped ``[batch, seq_len, vocab_size]``.
        attention_mask: Mask shaped ``[batch, seq_len]`` (1 = real token).

    Returns:
        Term weights shaped ``[batch, vocab_size]``.
    """
    term_scores = torch.log1p(torch.relu(logits))
    # Zero out padding positions so they cannot contribute to the max-pool.
    # Scores are >= 0, so masking never changes the result for real tokens.
    term_scores = term_scores * attention_mask.unsqueeze(-1)
    return term_scores.max(dim=1).values


@app.post("/splade")
def get_splade_vector(request: TextRequest):
    """Encode ``request.text`` into a sparse SPLADE vector.

    Returns:
        ``{"indices": [...], "values": [...]}``: the vocabulary ids with a
        non-zero weight and the corresponding weights, in ascending id order.

    Raises:
        HTTPException: 400 if the text is empty or whitespace-only.
    """
    text = request.text
    if not text.strip():
        raise HTTPException(status_code=400, detail="Input text cannot be empty.")

    # inference_mode disables autograd tracking (like no_grad) with less overhead.
    with torch.inference_mode():
        # Tokenize and move the tensors to the model's device (GPU or CPU).
        inputs = tokenizer(
            text,
            return_tensors="pt",
            truncation=True,
            max_length=MAX_LENGTH,
        ).to(device)

        logits = model(**inputs).logits  # [1, seq_len, vocab_size]
        term_importance = _splade_pool(logits, inputs["attention_mask"]).squeeze(0)  # [vocab_size]

        # Keep only the vocabulary entries with a non-zero weight.
        nz = torch.nonzero(term_importance, as_tuple=True)[0]
        weights = term_importance[nz]

        # Convert to plain Python lists on the CPU for JSON serialization.
        indices = nz.cpu().tolist()
        values = weights.cpu().float().tolist()

    return {"indices": indices, "values": values}


if __name__ == "__main__":
    # Serve the app when run as a script; uvicorn defaults to 127.0.0.1:8000.
    uvicorn.run(app)