Spaces:
Sleeping
Sleeping
| import torch | |
| from transformers import AutoTokenizer, AutoModelForMaskedLM | |
| from fastapi import FastAPI, HTTPException | |
| from pydantic import BaseModel | |
| import uvicorn | |
# FastAPI application object for this service.
app = FastAPI()

# --- Model setup: executed once, when the module is imported ---
MODEL_NAME = "naver/splade-cocondenser-ensembledistil"

# Use CUDA when torch reports it as available, otherwise stay on the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

print(f"Loading model on: {device}...")
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForMaskedLM.from_pretrained(MODEL_NAME)
model = model.to(device)
model.eval()  # evaluation mode: the model is only used for inference here
print("Model loaded successfully.")
| # Input Schema | |
class TextRequest(BaseModel):
    """Request body schema: a single piece of text to be encoded."""

    # Raw input text supplied by the caller.
    text: str
def home():
    """Return a small status payload naming the device selected at startup."""
    # NOTE(review): no route decorator is visible on this function in this
    # view of the file - confirm it is registered on `app`.
    payload = {"status": "SPLADE API is running", "device": device}
    return payload
def get_splade_vector(request: TextRequest):
    """Encode the request text as a sparse SPLADE term-weight vector.

    The text is tokenized (truncated to at most 512 tokens) and passed
    through the masked-LM model. Each vocabulary term is then scored as the
    maximum over token positions of ``log(1 + ReLU(logit))``, and only the
    non-zero entries are returned.

    Args:
        request: Request body carrying the ``text`` to encode.

    Returns:
        A dict with two parallel lists: ``indices`` (vocabulary positions
        whose weight is non-zero) and ``values`` (the matching weights as
        Python floats).

    Raises:
        HTTPException: With status 400 when ``text`` is empty or contains
            only whitespace.
    """
    # NOTE(review): no route decorator is visible on this function in this
    # view of the file - confirm it is registered on `app`.
    text = request.text
    if not text.strip():
        raise HTTPException(status_code=400, detail="Input text cannot be empty.")

    # Tokenize and move the tensors to the same device as the model
    # (which may be the CPU, not necessarily a GPU).
    inputs = tokenizer(
        text, return_tensors="pt", truncation=True, max_length=512
    ).to(device)

    # Inference only: disable autograd so the forward pass does not record a
    # gradient graph over the [1, seq_len, vocab_size] logits on every request.
    with torch.no_grad():
        logits = model(**inputs).logits  # [1, seq_len, vocab_size]

        # SPLADE logic: log(1 + ReLU(logits)), then max-pool over positions.
        term_scores = torch.log1p(torch.relu(logits))
        term_importance = term_scores.max(dim=1).values.squeeze(0)  # [vocab_size]

    # Keep only the vocabulary entries with a non-zero weight.
    nz = torch.nonzero(term_importance, as_tuple=True)[0]
    weights = term_importance[nz]

    # Convert to plain Python lists on the CPU so the result is serializable.
    indices = nz.cpu().tolist()
    values = weights.cpu().float().tolist()
    return {"indices": indices, "values": values}