"""Inference endpoint handler for call transcripts.

Provides sentiment analysis on the original transcript and a
translate -> summarize -> translate-back pipeline built on Marian
translation models and a Mistral summarization model.
"""
import gc
import logging
import re
import time
from collections import Counter
from typing import List, Optional

import torch
from pydantic import BaseModel, ValidationError
from transformers import (
    AutoTokenizer,
    AutoModelForSequenceClassification,
    MarianTokenizer,
    MarianMTModel,
    pipeline,
    PreTrainedTokenizer,
    AutoModelForCausalLM,
    BitsAndBytesConfig
)

# Set up logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

# Instruction delimiters wrapped around the summarization prompt.
B_INST, E_INST = "[INST] ", " [/INST]"

# Token budgets used when chunking text for each model.
MAX_SUMMARY_TOKENS = 15900
MAX_TRANSLATION_TOKENS = 256
MAX_SENTIMENT_TOKENS = 256

# Speaker labels that delimit turns in a transcript. The capture group makes
# re.split() keep the labels in its result.
_SPEAKER_PATTERN = re.compile(r"(Agent:|Customer:)")

# Instruction used by summarize_with_mistral() when the caller gives none.
_DEFAULT_SUMMARY_PROMPT = (
    "You are given a transcript of a phone call between a customer and a company agent. "
    "The agent is always the one who initiates the call. "
    "Summarize the conversation clearly and concisely for an internal admin reader. "
    "Focus on the purpose of the call, key discussion points, any actions taken or agreed upon, "
    "and any relevant customer sentiment or concerns."
)


class InferenceConfig(BaseModel):
    """Per-request options read from the ``parameters`` field of a request."""

    # Run sentiment analysis on the conversation.
    sentiment: bool = False
    # Run the translate -> summarize -> translate-back pipeline.
    summarize: bool = False
    # Optional instruction that replaces the default summarization prompt.
    prompt: Optional[str] = None


def _count_tokens(text: str, tokenizer) -> int:
    """Return the number of tokens in *text*, excluding special tokens."""
    return len(tokenizer(text, add_special_tokens=False)["input_ids"])


def _cuda_autocast():
    """Return a CUDA autocast context that is active only when CUDA is available."""
    return torch.amp.autocast("cuda", enabled=torch.cuda.is_available())


def _split_oversized_turn(speaker: str, utterance: str, max_tokens: int, tokenizer) -> List[str]:
    """Split one turn that exceeds *max_tokens* into word-aligned pieces.

    Every piece repeats the speaker label (when there is one). A single word
    that is by itself over the limit is emitted as its own piece.
    """
    label = f"{speaker} " if speaker else ""
    pieces = []
    partial = []
    for word in utterance.split():
        candidate = label + " ".join(partial + [word])
        if _count_tokens(candidate, tokenizer) > max_tokens:
            # Adding this word would overflow: close the piece built so far.
            if partial:
                pieces.append(label + " ".join(partial))
            partial = [word]
        else:
            partial.append(word)
    if partial:
        pieces.append(label + " ".join(partial))
    return pieces


def split_by_token_limit_with_speakers(text: str, max_tokens: int, tokenizer: PreTrainedTokenizer) -> List[str]:
    """
    Split transcript text into chunks with max token length, preserving speaker turns.

    Args:
        text: Transcript whose turns are introduced by ``Agent:`` / ``Customer:``.
            Text before the first label (or text with no labels at all) is
            treated as a single turn without a speaker.
        max_tokens: Maximum number of tokens per chunk, measured with
            *tokenizer* and without special tokens.
        tokenizer: Callable tokenizer returning a mapping with ``input_ids``.

    Returns:
        Chunks in transcript order. Chunk sizes are estimated by summing the
        per-turn token counts; a turn longer than *max_tokens* is split on
        word boundaries.
    """
    # With a capture group, re.split() returns
    # [prefix, label, content, label, content, ...] (always odd length).
    parts = _SPEAKER_PATTERN.split(text)

    turns = []
    leading_text = parts[0].strip()
    if leading_text:
        # Unlabelled text must not be paired up as if it were a speaker label.
        turns.append(("", leading_text))
    turns.extend(
        (label, content.strip())
        for label, content in zip(parts[1::2], parts[2::2])
    )

    chunks = []
    current_chunk = ""
    current_token_count = 0

    for speaker, utterance in turns:
        turn_text = f"{speaker} {utterance}" if speaker else utterance
        turn_tokens = _count_tokens(turn_text, tokenizer)

        # If turn fits, add to current chunk
        if current_token_count + turn_tokens <= max_tokens:
            current_chunk = f"{current_chunk} {turn_text}" if current_chunk else turn_text
            current_token_count += turn_tokens
            continue

        # The turn does not fit: finalize what was accumulated so far, so that
        # earlier turns are never dropped and chunk order follows the transcript.
        if current_chunk:
            chunks.append(current_chunk)

        if turn_tokens > max_tokens:
            # If turn itself is too long, split by words
            chunks.extend(_split_oversized_turn(speaker, utterance, max_tokens, tokenizer))
            current_chunk = ""
            current_token_count = 0
        else:
            # Start a new chunk with this turn
            current_chunk = turn_text
            current_token_count = turn_tokens

    if current_chunk:
        chunks.append(current_chunk)

    return chunks


def translate_chunks(chunks: list[str], tokenizer: MarianTokenizer, model: MarianMTModel, device: str = "cpu",
                     batch_size: int = 8) -> list[str]:
    """Translate text chunks in batches with a Marian model.

    Args:
        chunks: Texts to translate.
        tokenizer: Tokenizer matching *model*.
        model: Translation model; it is switched to eval mode.
        device: Device the tokenized inputs are moved to.
        batch_size: Number of chunks translated per ``generate`` call.

    Returns:
        One translated string per input chunk, in input order.

    Raises:
        ValueError: If *batch_size* is smaller than 1.
    """
    if batch_size < 1:
        raise ValueError("batch_size must be at least 1")

    model.eval()
    translated_chunks = []
    start_time = time.time()

    for i in range(0, len(chunks), batch_size):
        batch = chunks[i:i + batch_size]
        # Inputs longer than 512 tokens are truncated by the tokenizer.
        inputs = tokenizer(
            batch,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=512
        ).to(device)

        with torch.no_grad(), _cuda_autocast():
            translated = model.generate(**inputs)

        translated_texts = tokenizer.batch_decode(
            translated,
            skip_special_tokens=True
        )
        translated_chunks.extend(translated_texts)

    gen_time = time.time() - start_time
    logger.info(f"Generated translation in {gen_time:.2f}s.")
    return translated_chunks


def summarize_with_mistral(text: str, model, tokenizer, prompt_helper: Optional[str] = None) -> str:
    """
    Summarize text using the Mistral summarization model.

    Args:
        text: Text to summarize.
        model: Causal language model used for generation.
        tokenizer: Tokenizer matching *model*.
        prompt_helper: Instruction placed before the text; a default
            call-transcript instruction is used when ``None``.

    Returns:
        The generated summary with surrounding whitespace removed.
    """
    if prompt_helper is None:
        prompt_helper = _DEFAULT_SUMMARY_PROMPT

    # Format prompt according to model's requirements
    prompt = f"""{B_INST}{prompt_helper}\n\n[TEXT_START]\n\n{text}\n\n[TEXT_END]\n\n{E_INST}"""

    # Tokenize with attention mask and move to the same device as the model
    inputs = tokenizer(
        prompt,
        return_tensors="pt",
        truncation=True,
        max_length=16000,
        return_attention_mask=True
    ).to(model.device)
    prompt_length = inputs["input_ids"].shape[1]
    logger.info(f'Length of input tokens: {prompt_length}')

    with torch.no_grad(), _cuda_autocast():
        outputs = model.generate(
            **inputs,
            max_new_tokens=256,
            temperature=0.3,
            top_p=0.9,
            top_k=20,
            repetition_penalty=1.2,
            do_sample=True,
            pad_token_id=tokenizer.eos_token_id
        )

    # The output sequence starts with the prompt tokens. Decode only what
    # follows them: this still works when truncation removed the closing
    # "[/INST]" tag or when the transcript itself contains that marker.
    generated_text = tokenizer.decode(outputs[0][prompt_length:], skip_special_tokens=True)
    # Drop anything up to a closing tag the model may have echoed.
    summary = generated_text.split("[/INST]")[-1].strip()
    return summary


def handle_long_summarization(text: str, tokenizer, model, prompt_helper: Optional[str] = None) -> str:
    """Handle texts longer than context window using recursive summarization

    The text is split into chunks of at most half of ``MAX_SUMMARY_TOKENS``,
    each chunk is summarized, and the joined chunk summaries are summarized
    again (recursing while they still exceed ``MAX_SUMMARY_TOKENS``).

    Args:
        text: Text to summarize.
        tokenizer: Tokenizer of the summarization model.
        model: Summarization model.
        prompt_helper: Optional instruction forwarded to
            ``summarize_with_mistral``.

    Returns:
        The final summary.
    """
    chunks = split_by_token_limit_with_speakers(
        text,
        MAX_SUMMARY_TOKENS // 2,  # Use half context for chunking
        tokenizer
    )
    logger.info(f"Processing {len(chunks)} chunks for recursive summarization")

    summaries = []
    for i, chunk in enumerate(chunks):
        logger.info(f"Summarizing chunk {i + 1}/{len(chunks)}")
        summaries.append(summarize_with_mistral(chunk, model, tokenizer, prompt_helper))

    combined = " ".join(summaries)

    # Recurse if the combined summaries are still too long
    if len(tokenizer(combined)["input_ids"]) > MAX_SUMMARY_TOKENS:
        return handle_long_summarization(combined, tokenizer, model, prompt_helper)

    return summarize_with_mistral(combined, model, tokenizer, prompt_helper)


class EndpointHandler():
    """Callable handler that loads all models once and serves requests.

    A model that fails to load is logged and left as ``None``; requests that
    need it receive an error entry instead of a result.
    """

    def __init__(self, path=""):
        """Load the sentiment, translation and summarization models.

        Args:
            path: Not used by this implementation.
        """
        # Determine device
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        logger.info(f"Using device: {self.device}")

        # Initialize sentiment analysis pipeline
        self.sentiment_pipeline = None
        self.sentiment_tokenizer = None
        try:
            self.sentiment_tokenizer = AutoTokenizer.from_pretrained("KBLab/megatron-bert-large-swedish-cased-165k")
            sentiment_model = AutoModelForSequenceClassification.from_pretrained(
                "KBLab/robust-swedish-sentiment-multiclass")
            self.sentiment_pipeline = pipeline(
                "sentiment-analysis",
                model=sentiment_model,
                tokenizer=self.sentiment_tokenizer,
                device=self.device
            )
            logger.info("Sentiment analysis pipeline initialized")
        except Exception as e:
            logger.error(f"Error initializing sentiment pipeline: {e}")

        # Initialize Swedish to English translation
        self.model_sv_en = None
        self.tokenizer_sv_en = None
        try:
            model_name_sv_en = "Helsinki-NLP/opus-mt-sv-en"
            self.tokenizer_sv_en = MarianTokenizer.from_pretrained(model_name_sv_en)
            self.model_sv_en = MarianMTModel.from_pretrained(model_name_sv_en)
            self.model_sv_en.to(self.device)
            logger.info("Swedish to English translation model initialized")
        except Exception as e:
            logger.error(f"Error initializing sv-en translation: {e}")

        # Initialize English to Swedish translation
        self.model_en_sv = None
        self.tokenizer_en_sv = None
        try:
            model_name_en_sv = "Helsinki-NLP/opus-mt-en-sv"
            self.tokenizer_en_sv = MarianTokenizer.from_pretrained(model_name_en_sv)
            self.model_en_sv = MarianMTModel.from_pretrained(model_name_en_sv)
            self.model_en_sv.to(self.device)
            logger.info("English to Swedish translation model initialized")
        except Exception as e:
            logger.error(f"Error initializing en-sv translation: {e}")

        # Initialize summarization model (using simpler device management)
        self.summarizer_model = None
        self.summarizer_tokenizer = None
        try:
            model_name = "Trelis/Mistral-7B-Instruct-v0.1-Summarize-16k"
            self.summarizer_tokenizer = AutoTokenizer.from_pretrained(model_name)

            # Set pad token if not already set
            if self.summarizer_tokenizer.pad_token is None:
                self.summarizer_tokenizer.pad_token = self.summarizer_tokenizer.eos_token

            # Define the quantization config (4-bit in this case)
            bnb_config = BitsAndBytesConfig(
                load_in_4bit=True,  # or load_in_8bit=True
                bnb_4bit_compute_dtype="float16",
                bnb_4bit_use_double_quant=True,
                bnb_4bit_quant_type="nf4",  # "nf4" is better than "fp4" in most cases
            )

            self.summarizer_model = AutoModelForCausalLM.from_pretrained(
                model_name,
                torch_dtype=torch.float16,
                device_map="auto",
                quantization_config=bnb_config,
            )
            logger.info("Summarization model initialized")
        except Exception as e:
            logger.error(f"Error initializing summarizer: {e}")

    def __call__(self, inputs):
        """Process one request.

        Args:
            inputs: Mapping with the conversation under ``"inputs"`` and an
                optional ``"parameters"`` mapping matching ``InferenceConfig``.

        Returns:
            A dict that may contain ``"sentiment"`` and/or ``"summary"``
            entries (each either a result or ``{"error": ...}``), or
            ``{"error": ...}`` when the parameters are invalid.
        """
        # Extract inputs and parameters; an explicit null for "parameters"
        # is treated like a missing one.
        conversation = inputs.get("inputs", "")
        parameters = inputs.get("parameters") or {}

        # Validate parameters (TypeError covers a non-mapping "parameters")
        try:
            config = InferenceConfig(**parameters)
        except (ValidationError, TypeError) as e:
            logger.error(f"Error validating parameters: {e}")
            return {"error": f"Error validating parameters: {e}"}

        logger.info(f"Processing conversation with parameters: {config}")

        # Initialize result dictionary
        result = {}

        # Perform sentiment analysis
        if config.sentiment:
            if self.sentiment_pipeline is not None and self.sentiment_tokenizer is not None:
                logger.info("Performing sentiment analysis...")
                result["sentiment"] = self._analyze_sentiment(conversation)
            else:
                # Report the missing model instead of silently omitting the key.
                logger.error("Sentiment analysis error: sentiment model is not available")
                result["sentiment"] = {"error": "Sentiment analysis failed: sentiment model is not available"}

        # Perform translation and summarization
        if config.summarize:
            models = (self.summarizer_model, self.model_sv_en, self.model_en_sv)
            if all(m is not None for m in models):
                logger.info("Performing translation and summarization...")
                result["summary"] = self._summarize(conversation, config.prompt)
            else:
                logger.error("Translation or summarization error: required models are not available")
                result["summary"] = {"error": "Summarization failed: required models are not available"}

        torch.cuda.empty_cache()
        gc.collect()
        return result

    def _analyze_sentiment(self, conversation):
        """Classify each chunk of the conversation and aggregate the labels.

        Returns the most common label with the mean score, or an error dict
        if the analysis fails as a whole. Chunks that fail individually are
        logged and skipped.
        """
        try:
            text_chunks_for_sentiment = split_by_token_limit_with_speakers(
                conversation, MAX_SENTIMENT_TOKENS, self.sentiment_tokenizer)

            sentiment_results = []
            for i, chunk in enumerate(text_chunks_for_sentiment):
                if not chunk.strip():  # Skip empty chunks
                    continue
                try:
                    chunk_results = self.sentiment_pipeline(chunk)
                except Exception as chunk_error:
                    logger.warning(f"Failed to analyze sentiment for chunk {i}: {chunk_error}")
                    continue
                if isinstance(chunk_results, list):
                    sentiment_results.extend(chunk_results)
                else:
                    sentiment_results.append(chunk_results)

            if sentiment_results:
                labels = [r['label'] for r in sentiment_results]
                scores = [r['score'] for r in sentiment_results]
                label_counts = Counter(labels)

                # Get most common sentiment
                most_common_label = label_counts.most_common(1)[0][0]
                avg_score = sum(scores) / len(scores)

                sentiment = {
                    "label": most_common_label,
                    "score": round(avg_score, 2),
                    "details": {
                        "all_sentiments": labels,
                        "distribution": dict(label_counts)
                    }
                }
            else:
                sentiment = {"label": "neutral", "score": 0.0}

            torch.cuda.empty_cache()
            del text_chunks_for_sentiment
            gc.collect()
            logger.info(f"Sentiment analysis completed: {sentiment['label']}")
            return sentiment

        except Exception as e:
            logger.error(f"Sentiment analysis error: {str(e)}")
            return {"error": f"Sentiment analysis failed: {str(e)}"}

    def _summarize(self, conversation, prompt_helper=None):
        """Translate to English, summarize, and translate the summary back.

        Returns a dict with ``"swedish"`` and ``"english"`` summaries, or an
        error dict when no summary was produced or a step raised.
        """
        try:
            # Translate Swedish to English for summarization
            logger.info("Translating to English...")
            text_chunks_for_translation = split_by_token_limit_with_speakers(
                conversation, MAX_TRANSLATION_TOKENS, self.tokenizer_sv_en)

            translated_chunks = translate_chunks(
                text_chunks_for_translation,
                self.tokenizer_sv_en,
                self.model_sv_en,
                device=self.device
            )

            # Combine translated chunks
            translated_conversation = " ".join(translated_chunks)
            logger.info(f"Translated conversation length: {len(translated_conversation)} chars")

            # Generate summary using Mistral
            logger.info("Generating summary with Mistral...")

            # Handle long texts
            token_count = len(self.summarizer_tokenizer(translated_conversation)["input_ids"])
            torch.cuda.empty_cache()
            gc.collect()

            if token_count > MAX_SUMMARY_TOKENS:
                logger.warning(f"Text too long ({token_count} tokens), using recursive summarization")
                english_summary = handle_long_summarization(
                    translated_conversation,
                    self.summarizer_tokenizer,
                    self.summarizer_model,
                    prompt_helper
                )
            else:
                english_summary = summarize_with_mistral(
                    translated_conversation,
                    self.summarizer_model,
                    self.summarizer_tokenizer,
                    prompt_helper
                )

            torch.cuda.empty_cache()
            del translated_conversation
            gc.collect()

            if not english_summary:
                summary = {"error": "Failed to generate summary"}
            else:
                logger.info(f"Generated English summary: {english_summary[:100]}...")

                # Translate summary back to Swedish
                logger.info("Translating summary back to Swedish...")
                swedish_summary_list = translate_chunks(
                    [english_summary],
                    self.tokenizer_en_sv,
                    self.model_en_sv,
                    device=self.device
                )
                # Fall back to the English text if nothing was translated.
                swedish_summary = swedish_summary_list[0] if swedish_summary_list else english_summary

                summary = {
                    "swedish": swedish_summary,
                    "english": english_summary
                }

            logger.info("Summarization completed")
            return summary

        except Exception as e:
            logger.error(f"Translation or summarization error: {str(e)}")
            return {"error": f"Summarization failed: {str(e)}"}