# Hugging Face Space: GLiClass zero-shot text classification demo
| from typing import Dict, List, Any | |
| import torch | |
| from transformers import AutoTokenizer | |
| from gliclass import GLiClassModel, ZeroShotClassificationPipeline | |
| import gradio as gr | |
| import sys | |
| from logger import logging | |
| from exception import CustomExceptionHandling | |
# Select the compute device: prefer CUDA when available, otherwise fall back to CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Cache of loaded pipelines keyed by (model_name, classification_type) so each
# model/mode combination is only loaded from disk once per process.
pipeline_cache = {}
def get_pipeline(
    model_name: str, classification_type: str, device: torch.device
) -> ZeroShotClassificationPipeline:
    """Return a cached classification pipeline, loading it on first use.

    Args:
        model_name: Hugging Face checkpoint identifier for a GLiClass model.
        classification_type: Either "single-label" or "multi-label".
        device: Torch device the pipeline should run on.

    Returns:
        The cached (or freshly built) ZeroShotClassificationPipeline.

    Raises:
        CustomExceptionHandling: If loading the model or tokenizer fails.
    """
    cache_key = (model_name, classification_type)
    cached = pipeline_cache.get(cache_key)
    if cached is not None:
        return cached

    print(f"Loading model {model_name} for {classification_type}...")
    try:
        classifier_model = GLiClassModel.from_pretrained(model_name)
        # Tokenizers for "modern" model variants are built with add_prefix_space=True.
        tokenizer_kwargs = {"add_prefix_space": True} if "modern" in model_name else {}
        tokenizer = AutoTokenizer.from_pretrained(model_name, **tokenizer_kwargs)
        fresh_pipeline = ZeroShotClassificationPipeline(
            classifier_model,
            tokenizer,
            classification_type=classification_type,
            device=device,
        )
    except Exception as e:
        print(f"Error loading model {model_name}: {e}")
        raise CustomExceptionHandling(e, sys) from e

    pipeline_cache[cache_key] = fresh_pipeline
    return fresh_pipeline
def classification(
    text: str,
    labels: str,
    threshold: float,
    multi_label: bool = False,
    model_name: str = "knowledgator/gliclass-base-v3.0",
) -> Dict[str, Any]:
    """Classify *text* against a comma-separated list of candidate *labels*.

    Args:
        text: The input text to classify.
        labels: Comma-separated candidate label names.
        threshold: Minimum confidence score for a label to be reported.
        multi_label: When True, use multi-label classification; otherwise single-label.
        model_name: GLiClass checkpoint to classify with.

    Returns:
        Mapping of predicted label name -> confidence score.

    Raises:
        gr.Error: If *text* or *labels* is empty (displayed to the user by Gradio).
        CustomExceptionHandling: If the classification itself fails.
    """
    # BUGFIX: validate inputs OUTSIDE the try block. Previously the gr.Error
    # raises were inside the try, so the blanket `except Exception` caught
    # them and re-wrapped them in CustomExceptionHandling — Gradio never got
    # to display the intended user-facing validation message.
    if not text:
        raise gr.Error("No text provided")
    if not labels:
        raise gr.Error("No labels provided")

    # Parse and clean labels, dropping empties produced by stray commas/whitespace.
    label_list: List[str] = [l.strip() for l in labels.split(",") if l.strip()]
    # Robustness: input like ",,," yields no usable labels — report it clearly
    # instead of passing an empty label list to the pipeline.
    if not label_list:
        raise gr.Error("No labels provided")

    classification_type = "multi-label" if multi_label else "single-label"
    try:
        pipeline = get_pipeline(model_name, classification_type, device)
        # The pipeline returns one result list per input text; we pass one text.
        results = pipeline(text, label_list, threshold=threshold)[0]
        # Flatten the per-label predictions into a {label: score} mapping.
        return {predict["label"]: predict["score"] for predict in results}
    except Exception as e:
        print(f"Error classifying text: {e}")
        raise CustomExceptionHandling(e, sys) from e