import sys
from typing import Any, Dict, List

import torch
from transformers import AutoTokenizer
from gliclass import GLiClassModel, ZeroShotClassificationPipeline
import gradio as gr

from logger import logging
from exception import CustomExceptionHandling

# Device setup: use CUDA when it is available, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Cache for loaded pipelines
# Key: (model_name, classification_type)
pipeline_cache = {}


def get_pipeline(
    model_name: str, classification_type: str, device: torch.device
) -> ZeroShotClassificationPipeline:
    """Return the pipeline for a model/classification type, loading it on first use.

    Pipelines are stored in the module-level ``pipeline_cache`` under the key
    ``(model_name, classification_type)``.

    Args:
        model_name: Identifier passed to ``GLiClassModel.from_pretrained`` and
            ``AutoTokenizer.from_pretrained``.
        classification_type: Forwarded to ``ZeroShotClassificationPipeline``
            as ``classification_type``.
        device: Device handed to the pipeline when it is first built. It is
            NOT part of the cache key, so a later call with a different
            device still returns the pipeline created by the first call.

    Returns:
        The cached (or freshly built and cached) pipeline.

    Raises:
        CustomExceptionHandling: Wraps any exception raised while loading the
            model or tokenizer or while constructing the pipeline.
    """
    cache_key = (model_name, classification_type)

    # Fast path: this combination has already been loaded.
    if cache_key in pipeline_cache:
        return pipeline_cache[cache_key]

    print(f"Loading model {model_name} for {classification_type}...")
    try:
        # Load model and tokenizer
        model = GLiClassModel.from_pretrained(model_name)
        if "modern" in model_name:
            # NOTE(review): presumably the "modern" checkpoints need
            # add_prefix_space=True for correct tokenization -- confirm
            # against the model cards.
            tokenizer = AutoTokenizer.from_pretrained(
                model_name, add_prefix_space=True
            )
        else:
            tokenizer = AutoTokenizer.from_pretrained(model_name)

        # Create pipeline
        pipeline = ZeroShotClassificationPipeline(
            model,
            tokenizer,
            classification_type=classification_type,
            device=device,
        )
    except Exception as e:
        print(f"Error loading model {model_name}: {e}")
        raise CustomExceptionHandling(e, sys) from e

    # Cache the pipeline only after it has been built successfully.
    pipeline_cache[cache_key] = pipeline
    return pipeline


def classification(
    text: str,
    labels: str,
    threshold: float,
    multi_label: bool = False,
    model_name: str = "knowledgator/gliclass-base-v3.0",
) -> Dict[str, Any]:
    """Classify ``text`` against a comma-separated list of candidate labels.

    Args:
        text: The text to classify. Must be non-empty.
        labels: Candidate labels separated by commas. Surrounding whitespace
            is stripped from each label and empty entries are dropped.
        threshold: Forwarded to the pipeline call as ``threshold``.
        multi_label: When true the pipeline is requested with classification
            type ``"multi-label"``, otherwise ``"single-label"``.
        model_name: Model identifier used to obtain the pipeline.

    Returns:
        A dictionary mapping each ``"label"`` in the pipeline output to its
        ``"score"``.

    Raises:
        gr.Error: If ``text`` is empty, ``labels`` is empty, or ``labels``
            contains no usable label after parsing.
        CustomExceptionHandling: Wraps any exception raised while obtaining
            the pipeline or running the classification.
    """
    # Validate input BEFORE entering the try block. Raised inside it, these
    # gr.Error instances would be caught by ``except Exception`` and re-raised
    # as CustomExceptionHandling instead of reaching the caller as gr.Error.
    if not text:
        raise gr.Error("No text provided")
    if not labels:
        raise gr.Error("No labels provided")

    # Parse and clean labels
    label_list: List[str] = [
        label.strip() for label in labels.split(",") if label.strip()
    ]
    # Input such as " , ," is non-empty but yields no labels at all.
    if not label_list:
        raise gr.Error("No labels provided")

    # Determine classification type
    classification_type = "multi-label" if multi_label else "single-label"

    try:
        # Get pipeline
        pipeline = get_pipeline(model_name, classification_type, device)

        # Classify text; a single text is passed in, so only the first
        # element of the pipeline output is used.
        results = pipeline(text, label_list, threshold=threshold)[0]

        # Parse results into a dictionary
        return {predict["label"]: predict["score"] for predict in results}
    except Exception as e:
        print(f"Error classifying text: {e}")
        raise CustomExceptionHandling(e, sys) from e