Sai Pranav Reddy
Clean lightweight deployment
968e24d
# src/summarization/ranker.py
import torch
from transformers import AutoTokenizer
from src.summarization.model import SentenceRanker
class ImportanceRanker:
    """Score sentences with a fine-tuned ``SentenceRanker`` model.

    The tokenizer and the model architecture are built from ``base_model``.
    Fine-tuned weights are then loaded from ``<model_dir>/model.safetensors``
    when that file exists.
    """

    def __init__(self, model_dir, base_model="nlpaueb/legal-bert-base-uncased"):
        """Build the tokenizer and model and load fine-tuned weights if present.

        Args:
            model_dir: Directory expected to contain ``model.safetensors``.
            base_model: Hugging Face identifier used both for the tokenizer
                and to construct the ``SentenceRanker`` architecture.

        Warns:
            RuntimeWarning: If ``model.safetensors`` is missing. The model is
                then left with whatever weights ``SentenceRanker(base_model)``
                initialises, i.e. NOT the fine-tuned ones.
        """
        import os
        import warnings

        from safetensors.torch import load_file

        # Load the tokenizer from the base model.
        self.tokenizer = AutoTokenizer.from_pretrained(base_model)
        # Initialize the custom architecture with the base model.
        self.model = SentenceRanker(base_model)

        # Load fine-tuned weights (strict: keys must match the architecture).
        weights_path = os.path.join(model_dir, "model.safetensors")
        if os.path.exists(weights_path):
            state_dict = load_file(weights_path)
            self.model.load_state_dict(state_dict)
        else:
            # Use a real warning rather than print() so callers/loggers can
            # observe or escalate it; scoring silently without the fine-tuned
            # weights would otherwise go unnoticed.
            warnings.warn(
                f"Could not find {weights_path}; scoring with weights that "
                "were not fine-tuned",
                RuntimeWarning,
                stacklevel=2,
            )

        self.model.eval()
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.model.to(self.device)

    def _logits(self, sentences):
        """Tokenize one batch, run the model without gradients, return raw logits."""
        inputs = self.tokenizer(
            sentences,
            truncation=True,
            padding=True,
            return_tensors="pt",
        ).to(self.device)
        with torch.no_grad():
            return self.model(**inputs)["logits"]

    def score(self, sentences, batch_size=None):
        """Return sigmoid-squashed model scores for ``sentences``.

        Args:
            sentences: Sentence(s) to score; passed to the tokenizer as-is.
                An empty list/tuple returns ``[]`` without calling the model.
            batch_size: If given, score at most this many sentences per
                forward pass to bound memory use; ``None`` (default) scores
                everything in a single pass.

        Returns:
            ``sigmoid(logits)`` converted with ``Tensor.tolist()``. The exact
            nesting follows the logits shape produced by ``SentenceRanker``
            (presumably one entry per sentence -- confirm against the model).

        Raises:
            ValueError: If ``batch_size`` is given and is less than 1.
        """
        # An empty batch would reach the tokenizer/model as a degenerate input.
        if not isinstance(sentences, str) and len(sentences) == 0:
            return []

        # A single string cannot be sliced into sentence batches.
        if batch_size is None or isinstance(sentences, str):
            return self._logits(sentences).sigmoid().cpu().tolist()

        if batch_size < 1:
            raise ValueError("batch_size must be a positive integer")

        # Move each chunk's probabilities off the device before the next
        # forward pass. Assumes dim 0 of the logits indexes sentences.
        chunks = [
            self._logits(sentences[start:start + batch_size]).sigmoid().cpu()
            for start in range(0, len(sentences), batch_size)
        ]
        return torch.cat(chunks, dim=0).tolist()