Shreshth2002's picture
Upload folder using huggingface_hub
9cbb56b verified
"""
Inference pipeline for DistilBERT sentiment analysis
File: infer.py (improved version)
"""
import torch
import os
from transformers import AutoTokenizer, AutoModelForSequenceClassification
# Module-level cache for the loaded model and tokenizer. Both stay None until
# load_trained_model() assigns them after a successful load.
_model = None
_tokenizer = None
def load_trained_model(model_path="./model"):
    """Load the saved model and tokenizer, reusing the cached pair when possible.

    The loaded objects are stored in the module-level ``_model`` and
    ``_tokenizer`` globals. The path they were loaded from is remembered, so a
    call with a different ``model_path`` reloads instead of silently returning
    the model that belongs to an earlier path.

    Args:
        model_path: Directory containing the saved model and tokenizer files.

    Returns:
        Tuple of (model, tokenizer).

    Raises:
        FileNotFoundError: If ``model_path`` does not exist.
    """
    global _model, _tokenizer

    # Checked on every call, including cache hits, so a missing directory is
    # always reported the same way.
    if not os.path.exists(model_path):
        raise FileNotFoundError(f"No model found at {model_path}. Please train the model first.")

    requested_path = os.path.abspath(model_path)
    cached_path = getattr(load_trained_model, "_cached_path", None)
    cache_populated = _model is not None and _tokenizer is not None

    # A populated cache with no recorded path means the globals were assigned
    # by other code; keep treating that as a hit so such callers still work.
    if cache_populated and cached_path in (None, requested_path):
        return _model, _tokenizer

    print(f"Loading model from {model_path}...")
    # Load into locals first: if the model fails to load, the cache must not be
    # left holding a tokenizer that does not match the cached model.
    tokenizer = AutoTokenizer.from_pretrained(model_path)
    model = AutoModelForSequenceClassification.from_pretrained(model_path)
    _model, _tokenizer = model, tokenizer
    load_trained_model._cached_path = requested_path
    print("Model loaded successfully!")
    return _model, _tokenizer
def predict_sentiment(text, model, tokenizer, max_length=256):
    """Predict the sentiment of a single text.

    Args:
        text: Input text string.
        model: Loaded sequence-classification model; class index 1 is
            reported as "Positive", every other index as "Negative".
        tokenizer: Loaded tokenizer matching ``model``.
        max_length: Maximum sequence length; longer inputs are truncated.

    Returns:
        Tuple of (predicted_label, confidence_score), where the label is
        "Positive" or "Negative" and the confidence is the softmax
        probability of the predicted class.
    """
    # A single sequence needs no padding. Padding to max_length would make the
    # model process max_length tokens on every call regardless of input size.
    inputs = tokenizer(
        text,
        return_tensors="pt",
        truncation=True,
        max_length=max_length,
    )

    # Move the inputs to the model's device when it exposes one, so a model
    # that was moved to an accelerator does not fail on a device mismatch.
    device = getattr(model, "device", None)
    if device is not None:
        inputs = {name: tensor.to(device) for name, tensor in inputs.items()}

    model.eval()
    with torch.no_grad():
        outputs = model(**inputs)
        probabilities = torch.nn.functional.softmax(outputs.logits, dim=-1)
        predicted_class = torch.argmax(probabilities, dim=-1).item()
        confidence = probabilities[0][predicted_class].item()

    label = "Positive" if predicted_class == 1 else "Negative"
    return label, confidence
def predict(text, model_path="./model", max_length=256):
    """Classify ``text`` with the saved model and return a lowercase label.

    Loads (or reuses) the model via ``load_trained_model`` and delegates the
    actual inference to ``predict_sentiment`` so the tokenization and forward
    pass live in one place.

    Args:
        text: Input text string.
        model_path: Path to the saved model.
        max_length: Maximum sequence length.

    Returns:
        "positive" or "negative" on success. Failures are reported as a
        string rather than raised: "Error: ..." when a FileNotFoundError
        occurs (e.g. the model path is missing), and "Prediction error: ..."
        for any other exception.
    """
    try:
        model, tokenizer = load_trained_model(model_path)
        # predict_sentiment returns "Positive"/"Negative" using the same
        # class-1-is-positive rule; this entry point exposes it lowercased.
        label, _confidence = predict_sentiment(
            text, model, tokenizer, max_length=max_length
        )
        return label.lower()
    except FileNotFoundError as e:
        return f"Error: {str(e)}"
    except Exception as e:
        # Deliberate catch-all: callers of this helper expect a string back
        # instead of an exception.
        return f"Prediction error: {str(e)}"