# Metaphor_Scoring_Model / Interactive.py
# Committed by pa90
# Commit 2595ee3 (verified): Add "interactive.py"
#!/usr/bin/env python3
import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification
class MetaphorScorer:
    """Score sentences for metaphorical novelty with a sequence classifier.

    Wraps a Hugging Face tokenizer and sequence-classification model loaded
    from ``model_path``. A predicted class index ``i`` is reported as the
    score ``i + 1``.
    """

    def __init__(self, model_path='.', device=None):
        """
        Initialize the metaphor scorer.

        Args:
            model_path: Path or Hugging Face repo ID.
                Default '.' uses current directory (where model files are)
                Or use 'your-username/Metaphor_Scoring_Model' to load from Hub
            device: Optional ``torch.device`` or device string (e.g. 'cpu',
                'cuda:0'). When None, CUDA is used if available, else CPU.
        """
        if device is None:
            device = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.device = device if isinstance(device, torch.device) else torch.device(device)
        print(f"Loading model from: {model_path}")
        self.tokenizer = AutoTokenizer.from_pretrained(model_path)
        self.model = AutoModelForSequenceClassification.from_pretrained(model_path)
        self.model.to(self.device)
        # Inference only: switch off training-time behavior such as dropout.
        self.model.eval()
        print(f"Model loaded on {self.device}")

    def score_sentence(self, sentence, *, max_length=256):
        """
        Score a sentence for metaphorical novelty.

        Args:
            sentence: Input sentence to score.
            max_length: Token length the input is truncated and padded to.
                Defaults to 256, the value previously hard-coded here.

        Returns:
            A ``(score, confidence)`` tuple:
            score: Predicted class index + 1 (1-4 for the 4-label model).
            confidence: Softmax probability of the predicted class (0-1).
        """
        # NOTE(review): fixed-length padding presumably mirrors how the model
        # was trained; the attention mask should make it equivalent to no
        # padding, but it costs extra compute per call — confirm before changing.
        inputs = self.tokenizer(
            sentence,
            return_tensors='pt',
            max_length=max_length,
            truncation=True,
            padding='max_length',
        )
        inputs = {k: v.to(self.device) for k, v in inputs.items()}
        with torch.no_grad():
            logits = self.model(**inputs).logits
        probabilities = torch.softmax(logits, dim=-1)
        # Batch of one sentence, so .item() yields a plain int index.
        predicted_class = torch.argmax(logits, dim=-1).item()
        confidence = probabilities[0][predicted_class].item()
        # Class indices are 0-based; scores are reported 1-based.
        return predicted_class + 1, confidence
def main(model_path='.'):
    """Run an interactive prompt that scores sentences typed on stdin.

    Type 'quit', 'exit' or 'q' to stop. End of input (Ctrl-D or exhausted
    piped stdin) and Ctrl-C at the prompt also exit cleanly.

    Args:
        model_path: Forwarded to MetaphorScorer; defaults to the current
            directory, as before.
    """
    scorer = MetaphorScorer(model_path)
    print("\n=== Metaphorical Sentence Scorer ===")
    print("Enter metaphorical sentences to get novelty scores (1-4)")
    print("Higher scores = Higher metaphorical novelty")
    print("Type 'quit' to exit\n")
    while True:
        try:
            sentence = input("Enter sentence: ").strip()
        except (EOFError, KeyboardInterrupt):
            # input() raises these on Ctrl-D / end of piped input and on
            # Ctrl-C; exit the loop instead of crashing with a traceback.
            print("\nGoodbye!")
            break
        if sentence.lower() in ('quit', 'exit', 'q'):
            print("Goodbye!")
            break
        if not sentence:
            continue
        score, confidence = scorer.score_sentence(sentence)
        print(f"Score: {score}/4 (confidence: {confidence:.3f})\n")


if __name__ == "__main__":
    main()