DeepShcool / src /pipeline.py
AndreiBar's picture
Update src/pipeline.py
41e765f verified
Raw
History Blame Contribute Delete
964 Bytes
import functools
import os

from transformers import (
    pipeline,
    AutoTokenizer,
    AutoModelForSequenceClassification,
)
# Single root directory for every Hugging Face cache used by this module.
# It can be overridden through the HF_CACHE_DIR environment variable.
CACHE_ROOT = os.environ.get("HF_CACHE_DIR", "/tmp/hf-cache")
os.makedirs(CACHE_ROOT, exist_ok=True)

# Point the Hugging Face cache variables at subdirectories of CACHE_ROOT.
# Assignment order matches the original: HF_HOME, TRANSFORMERS_CACHE, HF_HUB_CACHE.
# NOTE(review): transformers (and thus huggingface_hub) is imported above,
# before these variables are set. Those libraries generally read them at import
# time, so these assignments probably do not change this process's default
# cache locations (child processes still inherit them) — verify. The explicit
# cache_dir that predict_sentiment passes to from_pretrained does take effect.
# TRANSFORMERS_CACHE is also deprecated in recent transformers releases in
# favour of HF_HOME.
os.environ.update(
    HF_HOME=CACHE_ROOT,
    TRANSFORMERS_CACHE=os.path.join(CACHE_ROOT, "transformers"),
    HF_HUB_CACHE=os.path.join(CACHE_ROOT, "hub"),
)
@functools.lru_cache(maxsize=8)
def _load_pipeline(task: str, model_name: str, cache_dir):
    """Build a transformers pipeline once per argument tuple and reuse it.

    Loading the tokenizer and model weights is costly, so the constructed
    pipeline is memoized. The bounded ``maxsize`` limits how many models
    stay resident in memory. If loading raises, nothing is cached, and the
    next call retries.

    Args:
        task: Task name forwarded to ``transformers.pipeline``.
        model_name: Hub model id or local path for tokenizer and model.
        cache_dir: Download/cache directory passed to ``from_pretrained``,
            or None for the library default. It is part of the cache key, so
            a different cache location triggers a fresh load.

    Returns:
        The pipeline object built from the loaded tokenizer and model.
    """
    tokenizer = AutoTokenizer.from_pretrained(
        model_name, cache_dir=cache_dir, use_fast=True
    )
    # NOTE(review): a sequence-classification head is always loaded, so
    # `task` is presumably a classification task (e.g. "sentiment-analysis"
    # / "text-classification") — confirm with callers.
    model = AutoModelForSequenceClassification.from_pretrained(
        model_name, cache_dir=cache_dir
    )
    return pipeline(task=task, model=model, tokenizer=tokenizer)


def predict_sentiment(text: str, task: str, model_name: str):
    """Run ``text`` through a classification pipeline for ``model_name``.

    The tokenizer/model/pipeline are built on the first call for a given
    (task, model_name, cache directory) combination and reused afterwards,
    instead of being reloaded on every call.

    Args:
        text: Input text to classify.
        task: Task name forwarded to ``transformers.pipeline``.
        model_name: Hub model id or local path.

    Returns:
        Whatever the pipeline returns for ``text``. For text-classification
        this is presumably a list of ``{"label", "score"}`` dicts — depends
        on ``task``.
    """
    # Read at call time so a runtime change of TRANSFORMERS_CACHE is honoured.
    cache_dir = os.environ.get("TRANSFORMERS_CACHE")
    pipe = _load_pipeline(task, model_name, cache_dir)
    return pipe(text)