Spaces:
Runtime error
Runtime error
| import os | |
| from typing import Any | |
| from openai import OpenAI | |
| from rag_demo.rag.base.query import Query | |
| from rag_demo.rag.base.template_factory import RAGStep | |
| from transformers import AutoTokenizer, AutoModelForSequenceClassification | |
| from loguru import logger | |
| import torch | |
# Hugging Face Hub checkpoint consumed by QueryClassifier.generate.
# NOTE: per the original author, this model was trained on English greetings only.
model_name = "AdrienB134/greetings-classifier"

# Tokenizer and classifier are loaded once at import time (fetched from the
# Hugging Face Hub or the local cache) and shared by every caller in this module.
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSequenceClassification.from_pretrained(model_name)
class QueryClassifier(RAGStep):
    """RAG step that labels a query using the module-level sequence classifier.

    The returned label is looked up in ``model.config.id2label``. Presumably it
    tells the pipeline whether retrieval is needed (the mock path returns
    ``"Sources_needed"``). TODO confirm the full label set of the checkpoint.
    """

    def generate(self, query: Query) -> Any:
        """Classify ``query.content`` and return the predicted label name.

        Args:
            query: The query whose ``content`` text is classified.

        Returns:
            The label name of the highest-scoring class, or the fixed string
            ``"Sources_needed"`` when the step is in mock mode.
        """
        if self._mock:
            # Mock mode bypasses model inference entirely.
            return "Sources_needed"

        with torch.no_grad():  # inference only; no autograd graph needed
            # truncation=True clips over-long queries to the tokenizer's
            # maximum length instead of letting the forward pass fail on them.
            inputs = tokenizer(query.content, return_tensors="pt", truncation=True)
            logits = model(**inputs).logits
            # A single input string yields a batch of one; pick the best class
            # explicitly along the label axis.
            predicted_id = logits.argmax(dim=-1).item()
        return model.config.id2label[predicted_id]