Jiahuita
Attempt to resolve deployment issue
b2de734
raw
history blame
1.3 kB
from transformers import Pipeline
import tensorflow as tf
from tensorflow.keras.preprocessing.sequence import pad_sequences
import json
import os
def load_tokenizer(tokenizer_path):
    """Read and return the parsed JSON content of a tokenizer file.

    Args:
        tokenizer_path: Path to a JSON file such as ``tokenizer.json``.

    Returns:
        Whatever ``json.load`` produces for the file (typically a dict, or a
        str if the file holds a JSON-encoded string). No tokenizer object is
        constructed here.

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If the content is not valid JSON.
    """
    # JSON is UTF-8 by spec; do not depend on the platform's locale encoding
    # (e.g. cp1252 on Windows would garble non-ASCII vocabulary entries).
    with open(tokenizer_path, 'r', encoding='utf-8') as f:
        return json.load(f)
class NewsClassificationPipeline(Pipeline):
    """Classify news text as "foxnews" or "nbc" with a bundled Keras model.

    The Keras model (``news_classifier.h5``) and the Keras tokenizer config
    (``tokenizer.json``) are loaded from the directory containing this file
    unless usable objects are passed to the constructor.

    Calling the pipeline with a single string returns one
    ``{"label": ..., "score": ...}`` dict; calling it with a list of strings
    returns a list of such dicts.
    """

    # Every token-id sequence is padded/truncated to this length before
    # prediction.
    max_sequence_length = 128

    def __init__(self, model=None, tokenizer=None, **kwargs):
        """Load (or accept) the Keras model and tokenizer.

        Args:
            model: Optional pre-loaded ``tf.keras.Model``. Any other value
                (including ``None``) is ignored and the bundled
                ``news_classifier.h5`` is loaded instead.
            tokenizer: Optional pre-built tokenizer exposing
                ``texts_to_sequences``. Any other value (including ``None``)
                is ignored and the bundled ``tokenizer.json`` is loaded.
            **kwargs: Accepted for signature compatibility; unused.
        """
        # Pipeline.__init__ is intentionally not called: it requires a
        # positional `model` argument (so `super().__init__(**kwargs)` always
        # raised TypeError) and expects a transformers model (it reads
        # `model.config`), which a plain Keras model does not provide.
        here = os.path.dirname(os.path.abspath(__file__))

        if isinstance(model, tf.keras.Model):
            self.model = model
        else:
            self.model = tf.keras.models.load_model(
                os.path.join(here, 'news_classifier.h5'))

        if hasattr(tokenizer, 'texts_to_sequences'):
            self.tokenizer_config = None
            self.tokenizer = tokenizer
        else:
            # Raw parsed JSON, kept as an attribute for existing callers.
            self.tokenizer_config = load_tokenizer(
                os.path.join(here, 'tokenizer.json'))
            self.tokenizer = self._tokenizer_from_config(self.tokenizer_config)

    @staticmethod
    def _tokenizer_from_config(config):
        """Rebuild a Keras ``Tokenizer`` from parsed ``tokenizer.json`` content."""
        # Legacy tf.keras (Keras 2) text API.
        # NOTE(review): not present under Keras 3 — confirm the deployed TF version.
        from tensorflow.keras.preprocessing.text import tokenizer_from_json

        # tokenizer_from_json expects a JSON *string*. json.load yields a dict
        # when the file holds Tokenizer.to_json() output directly, or a str
        # when that output was itself JSON-encoded before being written.
        if not isinstance(config, str):
            config = json.dumps(config)
        return tokenizer_from_json(config)

    def _sanitize_parameters(self, **kwargs):
        """Required Pipeline hook; this pipeline takes no call-time parameters."""
        return {}, {}, {}

    def preprocess(self, inputs, **kwargs):
        """Convert a list of strings to padded token-id sequences."""
        sequences = self.tokenizer.texts_to_sequences(inputs)
        return pad_sequences(sequences, maxlen=self.max_sequence_length)

    def _forward(self, model_inputs, **kwargs):
        """Run the Keras model on padded token-id sequences."""
        return self.model.predict(model_inputs)

    def postprocess(self, model_outputs, **kwargs):
        """Turn one model output row per text into a label/score dict."""
        results = []
        for pred in model_outputs:
            # pred[0] is read as the "foxnews" score; presumably a sigmoid
            # probability — confirm against the model's output layer.
            prob = pred[0]
            label = "foxnews" if prob > 0.5 else "nbc"
            score = float(prob if label == "foxnews" else 1 - prob)
            results.append({"label": label, "score": score})
        return results

    def __call__(self, texts, **kwargs):
        """Classify one text or a batch of texts.

        Args:
            texts: A single string, or an iterable of strings.
            **kwargs: Accepted for signature compatibility; unused.

        Returns:
            A ``{"label", "score"}`` dict for a ``str`` input; otherwise a
            list of such dicts (empty for an empty batch).
        """
        # Decide the return shape before `texts` is normalised to a list.
        single = isinstance(texts, str)
        batch = [texts] if single else list(texts)
        if not batch:
            # Avoid calling model.predict on an empty batch.
            return []
        results = self.postprocess(self._forward(self.preprocess(batch)))
        return results[0] if single else results