"""
Utility functions for HuggingFace Enabling Sessions Spaces app
"""

from functools import lru_cache

import numpy as np
import torch
from transformers import (
    pipeline,
    AutoTokenizer,
    AutoModelForSequenceClassification,
    AutoModelForQuestionAnswering,
    AutoModelForSeq2SeqLM,
)

import config

# Lazy loading for heavy dependencies: these module-level singletons stay None
# until the matching get_*() helper is called for the first time.
_sbert_model = None
_qa_model = None
_qa_tokenizer = None
_summarization_model = None
_summarization_tokenizer = None

# Default upper bound (in tokens) on the length of an answer span chosen by run_qa().
_QA_MAX_ANSWER_TOKENS = 30


def get_sbert_model():
    """Lazy load Sentence-BERT model only when needed.

    ``sentence_transformers`` is imported inside the function so that the
    package is only required when similarity features are actually used.

    Returns:
        The cached ``SentenceTransformer`` for ``config.EMBEDDINGS_MODEL``,
        placed on the CPU.
    """
    global _sbert_model
    if _sbert_model is None:
        from sentence_transformers import SentenceTransformer

        _sbert_model = SentenceTransformer(config.EMBEDDINGS_MODEL, device="cpu")
    return _sbert_model


@lru_cache(maxsize=10)
def load_pipeline(task_type: str):
    """Load and cache a pipeline for the given task.

    Successful results are memoised per ``task_type`` by ``lru_cache``.
    Raised exceptions are not cached, so a failed load is retried on the
    next call.

    Args:
        task_type: One of ``"sentiment"``, ``"ner"`` or ``"summarization"``.

    Returns:
        A ``transformers`` pipeline running on the CPU.

    Raises:
        Exception: If the task is unknown or the model cannot be loaded. The
            original error is chained as ``__cause__``.
    """
    device = -1  # Use CPU (safer for Spaces)
    try:
        if task_type == "sentiment":
            return pipeline("sentiment-analysis", model=config.SENTIMENT_MODEL, device=device)
        if task_type == "ner":
            try:
                return pipeline(
                    "ner", model=config.NER_MODEL, device=device, aggregation_strategy="simple"
                )
            except Exception:
                # Fallback to another public NER model if primary ID fails.
                fallback_ner_model = "dbmdz/bert-large-cased-finetuned-conll03-english"
                return pipeline(
                    "ner", model=fallback_ner_model, device=device, aggregation_strategy="simple"
                )
        if task_type == "summarization":
            # `summarization` alias is not present in some transformers builds.
            return pipeline(
                "text2text-generation", model=config.SUMMARIZATION_MODEL, device=device
            )
        raise ValueError(f"Unknown task type: {task_type}")
    except Exception as e:
        # Keep the original traceback reachable for debugging.
        raise Exception(f"Error loading {task_type} pipeline: {str(e)}") from e


def get_qa_model():
    """Lazy load QA model and tokenizer.

    The globals are assigned only after both objects have loaded. A failure
    part-way through therefore never leaves a tokenizer cached without its
    model.

    Returns:
        ``(model, tokenizer)`` for ``config.QA_MODEL``, with the model in eval mode.
    """
    global _qa_model, _qa_tokenizer
    if _qa_model is None:
        tokenizer = AutoTokenizer.from_pretrained(config.QA_MODEL)
        model = AutoModelForQuestionAnswering.from_pretrained(config.QA_MODEL)
        model.eval()
        _qa_tokenizer, _qa_model = tokenizer, model
    return _qa_model, _qa_tokenizer


def get_summarization_model():
    """Lazy load Summarization model and tokenizer.

    As in ``get_qa_model``, the globals are set only once both objects have
    loaded successfully.

    Returns:
        ``(model, tokenizer)`` for ``config.SUMMARIZATION_MODEL``, with the
        model in eval mode.
    """
    global _summarization_model, _summarization_tokenizer
    if _summarization_model is None:
        tokenizer = AutoTokenizer.from_pretrained(config.SUMMARIZATION_MODEL)
        model = AutoModelForSeq2SeqLM.from_pretrained(config.SUMMARIZATION_MODEL)
        model.eval()
        _summarization_tokenizer, _summarization_model = tokenizer, model
    return _summarization_model, _summarization_tokenizer


def run_sentiment_analysis(text: str):
    """Run sentiment analysis on text.

    Only the first 512 characters are classified.

    Returns:
        The first pipeline result (a dict with ``label`` and ``score``), or
        ``{"label": "Unknown", "score": 0}`` if the pipeline returned nothing.
    """
    pipe = load_pipeline("sentiment")
    result = pipe(text[:512])  # Truncate to avoid token limit
    return result[0] if result else {"label": "Unknown", "score": 0}


def run_ner(text: str):
    """Run Named Entity Recognition on text.

    Only the first 512 characters are processed.

    Returns:
        The pipeline's list of entity dicts. On failure it returns a
        single-element list whose dict has ``entity_group == "ERROR"`` and
        an ``error`` message.
    """
    try:
        pipe = load_pipeline("ner")
        return pipe(text[:512])
    except Exception as e:
        return [{"word": "", "entity_group": "ERROR", "score": 0.0, "error": str(e)}]


def _context_token_mask(inputs, seq_len: int) -> torch.Tensor:
    """Return a bool mask over token positions marking tokens of the context (second sequence).

    Uses ``BatchEncoding.sequence_ids``, which only fast tokenizers provide.
    For slow tokenizers every position is allowed, which matches the previous
    unrestricted behaviour.
    """
    try:
        sequence_ids = inputs.sequence_ids(0)
    except (ValueError, AttributeError):
        return torch.ones(seq_len, dtype=torch.bool)
    # Special tokens map to None, question tokens to 0 and context tokens to 1.
    return torch.tensor([sid == 1 for sid in sequence_ids], dtype=torch.bool)


def _best_answer_span(start_logits, end_logits, allowed, max_answer_len: int):
    """Pick the highest-scoring valid answer span.

    A span ``(s, e)`` is valid when ``s <= e``, ``e - s < max_answer_len``,
    and both ends are ``allowed``. The span score is the mean of the start
    logit at ``s`` and the end logit at ``e``.

    Args:
        start_logits: 1-D tensor of start logits, one per token.
        end_logits: 1-D tensor of end logits, same length.
        allowed: 1-D bool tensor of candidate positions. If it has no True
            entry, every position is allowed.
        max_answer_len: Maximum span length in tokens (clamped to at least 1).

    Returns:
        ``(start, end_exclusive, score)``.
    """
    n = start_logits.shape[0]
    if not bool(allowed.any()):
        allowed = torch.ones(n, dtype=torch.bool)
    max_answer_len = max(1, int(max_answer_len))

    pair_scores = start_logits[:, None] + end_logits[None, :]
    ones = torch.ones(n, n, dtype=torch.bool)
    # Upper triangle enforces end >= start; removing the band at max_answer_len caps the length.
    valid = torch.triu(ones) & ~torch.triu(ones, diagonal=max_answer_len)
    valid &= allowed[:, None] & allowed[None, :]
    pair_scores = pair_scores.masked_fill(~valid, float("-inf"))

    start, end = divmod(int(pair_scores.argmax()), n)
    score = pair_scores[start, end].item() / 2
    return start, end + 1, score


def run_qa(context: str, question: str, max_answer_len: int = _QA_MAX_ANSWER_TOKENS):
    """Run question answering on context using direct model inference.

    The answer is the best-scoring span that lies inside the context, starts
    no later than it ends, and is at most ``max_answer_len`` tokens long.
    Only the context is truncated to fit 512 tokens, so the question is
    kept intact.

    Args:
        context: Passage to search for the answer.
        question: Question to answer.
        max_answer_len: Maximum answer length in tokens.

    Returns:
        A dict with ``answer`` (str), ``score`` (mean of the chosen start/end
        logits, not a probability), and ``start``/``end`` (token indices,
        ``end`` exclusive). On failure it returns
        ``{"error": ..., "answer": "Unable to answer", "score": 0}``.
    """
    try:
        model, tokenizer = get_qa_model()
        inputs = tokenizer(
            question,
            context,
            return_tensors="pt",
            truncation="only_second",
            max_length=512,
        )
        with torch.no_grad():
            outputs = model(**inputs)

        start_logits = outputs.start_logits[0]
        end_logits = outputs.end_logits[0]
        allowed = _context_token_mask(inputs, start_logits.shape[0])
        answer_start_idx, answer_end_idx, score = _best_answer_span(
            start_logits, end_logits, allowed, max_answer_len
        )

        answer_ids = inputs["input_ids"][0][answer_start_idx:answer_end_idx]
        answer = tokenizer.convert_tokens_to_string(tokenizer.convert_ids_to_tokens(answer_ids))
        return {
            "answer": answer.strip(),
            "score": float(score),
            "start": int(answer_start_idx),
            "end": int(answer_end_idx),
        }
    except Exception as e:
        return {"error": str(e), "answer": "Unable to answer", "score": 0}


def run_summarization(text: str):
    """Generate summary of text using direct model inference.

    The input is capped at 1024 characters before tokenization (which also
    truncates to 1024 tokens). This keeps CPU generation time bounded.

    Returns:
        The decoded summary, or ``"Error: <message>"`` on failure.
    """
    try:
        model, tokenizer = get_summarization_model()
        inputs = tokenizer(text[:1024], return_tensors="pt", max_length=1024, truncation=True)
        with torch.no_grad():
            summary_ids = model.generate(
                inputs["input_ids"],
                attention_mask=inputs.get("attention_mask"),
                max_length=150,
                min_length=30,
                num_beams=4,
                length_penalty=2.0,
                early_stopping=True,
                forced_bos_token_id=0,
            )
        summary = tokenizer.batch_decode(summary_ids, skip_special_tokens=True)[0]
        return summary.strip()
    except Exception as e:
        return f"Error: {str(e)}"


def compute_similarity(text1: str, text2: str):
    """Compute semantic similarity between two texts.

    Returns:
        The cosine similarity of the two Sentence-BERT embeddings as a float.
        On failure it returns an ``"Error: <message>"`` string.
    """
    try:
        from sentence_transformers import util

        model = get_sbert_model()
        embeddings = model.encode([text1, text2], convert_to_tensor=True)
        similarity = util.pytorch_cos_sim(embeddings[0], embeddings[1])
        return float(similarity.item())
    except Exception as e:
        return f"Error: {str(e)}"


@lru_cache(maxsize=8)
def _load_tokenizer(model_name: str):
    """Load and cache a tokenizer so repeated tokenize_text() calls skip from_pretrained()."""
    return AutoTokenizer.from_pretrained(model_name)


def tokenize_text(text: str, model_name: str = config.SENTIMENT_MODEL):
    """Tokenize text and show tokens with IDs.

    Args:
        text: Text to tokenize (truncated to 512 tokens).
        model_name: Hub ID of the tokenizer to use.

    Returns:
        A dict with ``tokens``, ``token_ids``, ``attention_mask`` and
        ``num_tokens``, or ``{"error": message}`` on failure.
    """
    try:
        tokenizer = _load_tokenizer(model_name)
        encoding = tokenizer(
            text, return_tensors="pt", padding=True, truncation=True, max_length=512
        )
        input_ids = encoding["input_ids"][0]
        tokens = tokenizer.convert_ids_to_tokens(input_ids)
        return {
            "tokens": tokens,
            "token_ids": input_ids.tolist(),
            "attention_mask": encoding["attention_mask"][0].tolist(),
            "num_tokens": len(tokens),
        }
    except Exception as e:
        return {"error": str(e)}


def _md_cell(value) -> str:
    """Render *value* safely inside a Markdown table cell (escape pipes, flatten newlines)."""
    return str(value).replace("|", "\\|").replace("\n", " ")


def format_tokenizer_output(tokenization_result):
    """Format tokenization result for display.

    Args:
        tokenization_result: A dict as returned by ``tokenize_text``.

    Returns:
        A Markdown string with the token count and a token/ID/mask table, or
        ``"Error: <message>"`` when the result carries an error.
    """
    if "error" in tokenization_result:
        return f"Error: {tokenization_result['error']}"

    parts = [
        f"**Total Tokens:** {tokenization_result['num_tokens']}\n\n",
        "| Token | Token ID | Attention Mask |\n",
        "|-------|----------|----------------|\n",
    ]
    parts.extend(
        f"| {_md_cell(token)} | {tid} | {attn} |\n"
        for token, tid, attn in zip(
            tokenization_result["tokens"],
            tokenization_result["token_ids"],
            tokenization_result["attention_mask"],
        )
    )
    return "".join(parts)


def format_ner_output(ner_results):
    """Format NER results for display.

    Args:
        ner_results: A list of entity dicts as returned by ``run_ner``.

    Returns:
        A Markdown table of entities. Returns ``"No entities found"`` for an
        empty result, or ``"Error: <message>"`` if ``run_ner`` reported a
        failure.
    """
    if not ner_results:
        return "No entities found"

    # run_ner() reports failure as a pseudo-entity carrying an "error" key.
    for result in ner_results:
        if "error" in result:
            return f"Error: {result['error']}"

    lines = ["| Entity | Type | Score |\n", "|--------|------|-------|\n"]
    for result in ner_results:
        word = result.get("word", "")
        entity_type = result.get("entity_group", result.get("entity", ""))
        score = result.get("score", 0)
        lines.append(f"| {_md_cell(word)} | {_md_cell(entity_type)} | {score:.4f} |\n")
    return "".join(lines)