# GLiClass-Playground / classifier.py
# Author: sitammeur
# Last commit: 67a0d7c ("Update classifier.py")
from typing import Dict, List, Any
import torch
from transformers import AutoTokenizer
from gliclass import GLiClassModel, ZeroShotClassificationPipeline
import gradio as gr
import sys
from logger import logging
from exception import CustomExceptionHandling
# Device setup: use CUDA when PyTorch reports it is available, else the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Module-level cache of constructed pipelines, so each (model, mode) pair is
# only built once per process.
# Key: (model_name, classification_type) -> pipeline instance
pipeline_cache = {}
def get_pipeline(
    model_name: str, classification_type: str, device: torch.device
) -> ZeroShotClassificationPipeline:
    """Return the pipeline for a model/mode pair, building and caching it on first use.

    Pipelines are memoized in the module-level ``pipeline_cache`` under the key
    ``(model_name, classification_type)``. ``device`` is not part of the key:
    once a pipeline is cached, it is returned as-is on later calls whatever
    ``device`` they pass.

    Args:
        model_name: Identifier passed to ``GLiClassModel.from_pretrained`` and
            ``AutoTokenizer.from_pretrained``.
        classification_type: Forwarded to ``ZeroShotClassificationPipeline``
            (this module passes ``"single-label"`` or ``"multi-label"``).
        device: Forwarded to ``ZeroShotClassificationPipeline`` as ``device``.

    Returns:
        The cached or newly constructed pipeline.

    Raises:
        CustomExceptionHandling: Wraps any error raised while loading the model
            or tokenizer, or while constructing the pipeline. Nothing is cached
            when this happens.
    """
    key = (model_name, classification_type)
    if key in pipeline_cache:
        return pipeline_cache[key]

    print(f"Loading model {model_name} for {classification_type}...")
    try:
        model = GLiClassModel.from_pretrained(model_name)
        # Checkpoints whose name contains "modern" get add_prefix_space=True;
        # presumably their tokenizer requires it (TODO confirm on the model card).
        tokenizer_kwargs = (
            {"add_prefix_space": True} if "modern" in model_name else {}
        )
        tokenizer = AutoTokenizer.from_pretrained(model_name, **tokenizer_kwargs)
        built = ZeroShotClassificationPipeline(
            model, tokenizer, classification_type=classification_type, device=device
        )
    except Exception as e:
        print(f"Error loading model {model_name}: {e}")
        raise CustomExceptionHandling(e, sys) from e

    # Only store the pipeline once construction has fully succeeded.
    pipeline_cache[key] = built
    return built
def classification(
    text: str,
    labels: str,
    threshold: float,
    multi_label: bool = False,
    model_name: str = "knowledgator/gliclass-base-v3.0",
) -> Dict[str, Any]:
    """Classify ``text`` against a comma-separated list of candidate labels.

    Args:
        text: The input text to classify.
        labels: Candidate labels as one comma-separated string, e.g.
            ``"sports, politics, tech"``. Whitespace around each label and
            empty entries are ignored. Repeated labels are kept once, in
            first-seen order.
        threshold: Passed through to the pipeline call as ``threshold``.
        multi_label: If True, use the ``"multi-label"`` pipeline; otherwise
            use the ``"single-label"`` one.
        model_name: Model identifier forwarded to ``get_pipeline``.

    Returns:
        A mapping from each label returned by the pipeline to its score.

    Raises:
        gr.Error: If ``text`` is empty or whitespace-only, or if ``labels``
            contains no non-empty label. Raised unwrapped so Gradio can show
            the message to the user.
        CustomExceptionHandling: For any other failure during model loading
            or inference.
    """
    try:
        # Reject missing or whitespace-only text before any model is loaded.
        if not text or not text.strip():
            raise gr.Error("No text provided")
        if not labels:
            raise gr.Error("No labels provided")

        # Split on commas, trim, drop empty entries, and de-duplicate while
        # keeping first-seen order. The result below is keyed by label, so
        # repeated labels would only collapse into a single entry anyway.
        label_list: List[str] = list(
            dict.fromkeys(
                label.strip() for label in labels.split(",") if label.strip()
            )
        )
        # Input such as " , ," is non-empty but yields no usable labels.
        if not label_list:
            raise gr.Error("No labels provided")

        classification_type = "multi-label" if multi_label else "single-label"
        pipeline = get_pipeline(model_name, classification_type, device)

        # The pipeline returns one result list per input text; we pass one text.
        results = pipeline(text, label_list, threshold=threshold)[0]

        return {predict["label"]: predict["score"] for predict in results}
    except gr.Error:
        # Let validation errors propagate unchanged so Gradio displays their
        # message, instead of a wrapped generic error.
        raise
    except CustomExceptionHandling:
        # get_pipeline has already wrapped this failure; don't wrap it twice.
        raise
    except Exception as e:
        print(f"Error classifying text: {e}")
        raise CustomExceptionHandling(e, sys) from e