| import gc |
| import logging |
| import re |
| import time |
| from collections import Counter |
| from typing import List, Optional |
|
|
| import torch |
| from pydantic import BaseModel, ValidationError |
| from transformers import ( |
| AutoTokenizer, |
| AutoModelForSequenceClassification, |
| MarianTokenizer, |
| MarianMTModel, |
| pipeline, |
| PreTrainedTokenizer, |
| AutoModelForCausalLM, BitsAndBytesConfig |
| ) |
|
|
| |
# Configure process-wide logging once, when this module is imported.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
)
logger = logging.getLogger(__name__)

# Instruction delimiters for the [INST] ... [/INST] prompt format used with Mistral.
B_INST = "[INST] "
E_INST = " [/INST]"

# Token budgets used when checking / splitting transcripts for each task.
MAX_SUMMARY_TOKENS = 15900
MAX_TRANSLATION_TOKENS = 256
MAX_SENTIMENT_TOKENS = 256
class InferenceConfig(BaseModel):
    """Per-request options parsed from the payload's "parameters" mapping.

    Both task flags default to False, so a request without parameters runs no task.
    """
    # When True, run sentiment analysis on the conversation.
    sentiment: bool = False
    # When True, produce a summary of the conversation (English and Swedish).
    summarize: bool = False
    # Optional custom instruction text; presumably meant as the summarizer prompt — confirm with callers.
    prompt: Optional[str] = None
|
|
|
|
def split_by_token_limit_with_speakers(text: str, max_tokens: int, tokenizer: "PreTrainedTokenizer") -> List[str]:
    """
    Split transcript text into chunks with max token length, preserving speaker turns.

    The text is split on the ``Agent:`` / ``Customer:`` labels. Consecutive turns are
    packed greedily into one chunk while the sum of their token counts fits
    ``max_tokens``. A single turn longer than ``max_tokens`` is split word by word into
    several chunks, each prefixed with its speaker label. Any text before the first
    label (including a text that has no labels at all, such as a summary) is treated
    as one unlabeled turn instead of being mis-paired with a label.

    Args:
        text: Transcript text.
        max_tokens: Token budget per chunk, counted without special tokens.
        tokenizer: Callable tokenizer returning a mapping with an ``"input_ids"`` list.

    Returns:
        The chunks in transcript order; an empty list for empty/whitespace-only text.

    Note:
        Chunk sizes are sums of per-turn token counts, so re-tokenizing a joined chunk
        can differ slightly. A single word longer than ``max_tokens`` still yields an
        oversized chunk.
    """

    def token_count(segment: str) -> int:
        return len(tokenizer(segment, add_special_tokens=False)["input_ids"])

    def format_turn(speaker: str, utterance: str) -> str:
        # Unlabeled (preamble) turns carry no speaker prefix.
        return f"{speaker} {utterance}" if speaker else utterance

    # With a capturing group, re.split returns [preamble, label, utterance, label, ...]:
    # labels sit at odd indices and each one is always followed by its utterance.
    parts = re.split(r"(Agent:|Customer:)", text)
    turns = []
    preamble = parts[0].strip()
    if preamble:
        turns.append(("", preamble))
    turns.extend((parts[i], parts[i + 1].strip()) for i in range(1, len(parts), 2))

    chunks = []
    current_chunk = ""
    current_token_count = 0

    for speaker, utterance in turns:
        turn_text = format_turn(speaker, utterance)
        turn_tokens = token_count(turn_text)

        if current_token_count + turn_tokens <= max_tokens:
            current_chunk = f"{current_chunk} {turn_text}" if current_chunk else turn_text
            current_token_count += turn_tokens
            continue

        # The turn does not fit: close the chunk built so far first so it is never dropped.
        if current_chunk:
            chunks.append(current_chunk)
        current_chunk = ""
        current_token_count = 0

        if turn_tokens <= max_tokens:
            current_chunk = turn_text
            current_token_count = turn_tokens
            continue

        # A single turn exceeds the budget: emit it as several word-split chunks.
        partial = []
        for word in utterance.split():
            if partial and token_count(format_turn(speaker, " ".join(partial + [word]))) > max_tokens:
                chunks.append(format_turn(speaker, " ".join(partial)))
                partial = [word]
            else:
                partial.append(word)
        if partial:
            chunks.append(format_turn(speaker, " ".join(partial)))

    if current_chunk:
        chunks.append(current_chunk)

    return chunks
|
|
|
|
def translate_chunks(
    chunks: list[str],
    tokenizer: MarianTokenizer,
    model: MarianMTModel,
    device: str = "cpu",
    batch_size: int = 8,
    max_length: int = 512,
) -> list[str]:
    """Translate text chunks in batches with a MarianMT model.

    Args:
        chunks: Texts to translate; output order matches input order.
        tokenizer: Tokenizer paired with ``model``.
        model: Translation model, expected to already live on ``device``.
        device: Device the encoded inputs are moved to (e.g. "cpu" or "cuda").
        batch_size: Number of chunks passed to one ``generate`` call.
        max_length: Inputs longer than this many tokens are truncated.

    Returns:
        One translated string per input chunk (empty list for no chunks).

    Raises:
        ValueError: If ``batch_size`` is smaller than 1.
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")

    model.eval()
    # CUDA mixed precision only applies to CUDA devices; requesting it on a CPU-only
    # host merely triggers a warning and disables itself, so enable it explicitly.
    use_cuda_amp = str(device).startswith("cuda") and torch.cuda.is_available()

    translated_chunks = []
    start_time = time.time()
    for start in range(0, len(chunks), batch_size):
        batch = chunks[start:start + batch_size]
        inputs = tokenizer(
            batch,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=max_length,
        ).to(device)

        with torch.no_grad(), torch.amp.autocast("cuda", enabled=use_cuda_amp):
            generated = model.generate(**inputs)

        translated_chunks.extend(tokenizer.batch_decode(generated, skip_special_tokens=True))

    gen_time = time.time() - start_time
    logger.info(f"Generated translation in {gen_time:.2f}s.")
    return translated_chunks
|
|
|
|
def summarize_with_mistral(
    text: str,
    model,
    tokenizer,
    prompt_helper: Optional[str] = None,
    max_input_tokens: int = 16000,
    max_new_tokens: int = 256,
) -> str:
    """
    Summarize text using the Mistral summarization model.

    The text is wrapped in an ``[INST] ... [/INST]`` instruction prompt together with
    ``prompt_helper`` (or a default call-summary instruction), and a sampled
    completion is generated.

    Args:
        text: Text to summarize (the handler passes the English translation).
        model: Causal LM exposing ``.device`` and ``.generate``.
        tokenizer: Tokenizer matching ``model``.
        prompt_helper: Instruction placed before the text; ``None`` selects the default.
        max_input_tokens: Maximum prompt length in tokens. When exceeded, the text is
            shortened so the closing ``[TEXT_END]`` / ``[/INST]`` markers survive.
        max_new_tokens: Maximum number of generated tokens.

    Returns:
        Only the generated continuation (prompt tokens are not decoded), stripped.
    """
    if prompt_helper is None:
        prompt_helper = "You are given a transcript of a phone call between a customer and a company agent. The agent is always the one who initiates the call. Summarize the conversation clearly and concisely for an internal admin reader. Focus on the purpose of the call, key discussion points, any actions taken or agreed upon, and any relevant customer sentiment or concerns."

    def build_prompt(body: str) -> str:
        return f"""{B_INST}{prompt_helper}\n\n[TEXT_START]\n\n{body}\n\n[TEXT_END]\n\n{E_INST}"""

    prompt = build_prompt(text)
    overflow = len(tokenizer(prompt)["input_ids"]) - max_input_tokens
    if overflow > 0:
        # Truncating the assembled prompt would chop off "[TEXT_END]" and the closing
        # [/INST] marker, so shorten the text itself. The small margin absorbs token
        # drift from decoding and re-encoding the cut text.
        text_ids = tokenizer(text, add_special_tokens=False)["input_ids"]
        keep = max(len(text_ids) - overflow - 16, 0)
        logger.warning(f"Prompt exceeds {max_input_tokens} tokens; truncating text to {keep} tokens")
        prompt = build_prompt(tokenizer.decode(text_ids[:keep], skip_special_tokens=True))

    inputs = tokenizer(
        prompt,
        return_tensors="pt",
        truncation=True,  # safety net only; the text was already shortened above
        max_length=max_input_tokens,
        return_attention_mask=True,
    ).to(model.device)
    input_length = inputs["input_ids"].shape[1]
    logger.info(f'Length of input tokens: {input_length}')

    with torch.no_grad(), torch.amp.autocast("cuda", enabled=model.device.type == "cuda"):
        outputs = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            temperature=0.3,
            top_p=0.9,
            top_k=20,
            repetition_penalty=1.2,
            do_sample=True,
            pad_token_id=tokenizer.eos_token_id,
        )

    # generate() returns prompt + continuation; decode only the new tokens so the
    # prompt (and the transcript inside it) can never leak into the summary.
    summary = tokenizer.decode(outputs[0, input_length:], skip_special_tokens=True)
    return summary.strip()
|
|
|
|
def _pack_texts_by_token_limit(texts: List[str], max_tokens: int, tokenizer) -> List[str]:
    """Greedily join consecutive texts with spaces into groups of at most ``max_tokens`` tokens.

    A single text larger than ``max_tokens`` forms its own group.
    """
    groups = []
    current = []
    current_tokens = 0
    for piece in texts:
        piece_tokens = len(tokenizer(piece, add_special_tokens=False)["input_ids"])
        if current and current_tokens + piece_tokens > max_tokens:
            groups.append(" ".join(current))
            current = []
            current_tokens = 0
        current.append(piece)
        current_tokens += piece_tokens
    if current:
        groups.append(" ".join(current))
    return groups


def handle_long_summarization(text: str, tokenizer, model, prompt_helper: Optional[str] = None) -> str:
    """Handle texts longer than the context window using hierarchical summarization.

    The text is split (by speaker turns) into chunks of at most ``MAX_SUMMARY_TOKENS // 2``
    tokens, and each chunk is summarized. While the joined summaries still exceed
    ``MAX_SUMMARY_TOKENS``, neighbouring summaries are packed into groups and re-summarized.
    The final combined text is then summarized once more.

    Args:
        text: Text to summarize.
        tokenizer: Tokenizer of the summarization model.
        model: Summarization model passed to ``summarize_with_mistral``.
        prompt_helper: Optional instruction forwarded to every ``summarize_with_mistral`` call.

    Returns:
        The final summary string.
    """
    chunk_limit = MAX_SUMMARY_TOKENS // 2
    chunks = split_by_token_limit_with_speakers(text, chunk_limit, tokenizer)

    logger.info(f"Processing {len(chunks)} chunks for recursive summarization")
    summaries = []
    for i, chunk in enumerate(chunks):
        logger.info(f"Summarizing chunk {i + 1}/{len(chunks)}")
        summaries.append(summarize_with_mistral(chunk, model, tokenizer, prompt_helper=prompt_helper))

    combined = " ".join(summaries)

    # Reduce iteratively instead of re-splitting the summaries by speaker labels, which
    # plain summaries do not contain. Every pass must shrink the list, or we stop and
    # let the final call truncate.
    while len(tokenizer(combined)["input_ids"]) > MAX_SUMMARY_TOKENS:
        groups = _pack_texts_by_token_limit(summaries, chunk_limit, tokenizer)
        if len(groups) >= len(summaries):
            break
        logger.info(f"Re-summarizing {len(summaries)} summaries in {len(groups)} groups")
        summaries = [summarize_with_mistral(group, model, tokenizer, prompt_helper=prompt_helper) for group in groups]
        combined = " ".join(summaries)

    return summarize_with_mistral(combined, model, tokenizer, prompt_helper=prompt_helper)
|
|
|
|
class EndpointHandler:
    """Request handler running Swedish sentiment analysis and/or call summarization.

    All models are loaded once in ``__init__``. A model that fails to load is left as
    ``None``. A request for a task that needs it gets an error entry rather than being
    silently skipped.
    """

    def __init__(self, path=""):
        """Load every model; ``path`` is accepted but unused (models are loaded by hub id)."""
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        logger.info(f"Using device: {self.device}")

        # Sentiment: sequence-classification model wrapped in a pipeline.
        self.sentiment_pipeline = None
        self.sentiment_tokenizer = None
        try:
            self.sentiment_tokenizer = AutoTokenizer.from_pretrained("KBLab/megatron-bert-large-swedish-cased-165k")
            sentiment_model = AutoModelForSequenceClassification.from_pretrained(
                "KBLab/robust-swedish-sentiment-multiclass")
            self.sentiment_pipeline = pipeline(
                "sentiment-analysis",
                model=sentiment_model,
                tokenizer=self.sentiment_tokenizer,
                device=self.device,
            )
            logger.info("Sentiment analysis pipeline initialized")
        except Exception as e:
            logger.exception(f"Error initializing sentiment pipeline: {e}")

        # Swedish -> English translation (the transcript is summarized in English).
        # The model is only published on self after it was moved to the device.
        self.model_sv_en = None
        self.tokenizer_sv_en = None
        try:
            model_name_sv_en = "Helsinki-NLP/opus-mt-sv-en"
            self.tokenizer_sv_en = MarianTokenizer.from_pretrained(model_name_sv_en)
            model_sv_en = MarianMTModel.from_pretrained(model_name_sv_en)
            model_sv_en.to(self.device)
            self.model_sv_en = model_sv_en
            logger.info("Swedish to English translation model initialized")
        except Exception as e:
            logger.exception(f"Error initializing sv-en translation: {e}")

        # English -> Swedish translation (for the final summary).
        self.model_en_sv = None
        self.tokenizer_en_sv = None
        try:
            model_name_en_sv = "Helsinki-NLP/opus-mt-en-sv"
            self.tokenizer_en_sv = MarianTokenizer.from_pretrained(model_name_en_sv)
            model_en_sv = MarianMTModel.from_pretrained(model_name_en_sv)
            model_en_sv.to(self.device)
            self.model_en_sv = model_en_sv
            logger.info("English to Swedish translation model initialized")
        except Exception as e:
            logger.exception(f"Error initializing en-sv translation: {e}")

        # Summarizer: Mistral instruct model loaded in 4-bit NF4 quantization.
        self.summarizer_model = None
        self.summarizer_tokenizer = None
        try:
            model_name = "Trelis/Mistral-7B-Instruct-v0.1-Summarize-16k"
            self.summarizer_tokenizer = AutoTokenizer.from_pretrained(model_name)
            # Generation needs a pad token; fall back to EOS when none is defined.
            if self.summarizer_tokenizer.pad_token is None:
                self.summarizer_tokenizer.pad_token = self.summarizer_tokenizer.eos_token

            bnb_config = BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_compute_dtype="float16",
                bnb_4bit_use_double_quant=True,
                bnb_4bit_quant_type="nf4",
            )
            self.summarizer_model = AutoModelForCausalLM.from_pretrained(
                model_name,
                torch_dtype=torch.float16,
                device_map="auto",
                quantization_config=bnb_config,
            )
            logger.info("Summarization model initialized")
        except Exception as e:
            logger.exception(f"Error initializing summarizer: {e}")

    @staticmethod
    def _release_memory():
        """Collect unreachable Python objects, then return cached CUDA blocks.

        gc runs first so that tensors held only by collectable objects are freed
        before the CUDA cache is emptied.
        """
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    def _analyze_sentiment(self, conversation: str) -> dict:
        """Classify speaker-aware chunks and aggregate them by majority label.

        Returns:
            ``{"label", "score", "details"}`` on success, a neutral fallback when no chunk
            could be classified, or ``{"error": ...}`` if the analysis failed.
        """
        logger.info("Performing sentiment analysis...")
        try:
            text_chunks_for_sentiment = split_by_token_limit_with_speakers(
                conversation, MAX_SENTIMENT_TOKENS, self.sentiment_tokenizer)

            sentiment_results = []
            for i, chunk in enumerate(text_chunks_for_sentiment):
                if not chunk.strip():
                    continue
                try:
                    chunk_results = self.sentiment_pipeline(chunk)
                except Exception as chunk_error:
                    # One bad chunk should not sink the whole analysis.
                    logger.warning(f"Failed to analyze sentiment for chunk {i}: {chunk_error}")
                    continue
                if isinstance(chunk_results, list):
                    sentiment_results.extend(chunk_results)
                else:
                    sentiment_results.append(chunk_results)

            if sentiment_results:
                labels = [r['label'] for r in sentiment_results]
                scores = [r['score'] for r in sentiment_results]
                label_counts = Counter(labels)
                sentiment = {
                    "label": label_counts.most_common(1)[0][0],
                    # NOTE(review): mean over all chunk scores, regardless of label.
                    "score": round(sum(scores) / len(scores), 2),
                    "details": {
                        "all_sentiments": labels,
                        "distribution": dict(label_counts),
                    },
                }
            else:
                sentiment = {"label": "neutral", "score": 0.0}

            logger.info(f"Sentiment analysis completed: {sentiment['label']}")
            return sentiment
        except Exception as e:
            logger.error(f"Sentiment analysis error: {str(e)}")
            return {"error": f"Sentiment analysis failed: {str(e)}"}
        finally:
            self._release_memory()

    def _summarize(self, conversation: str, prompt_helper: Optional[str] = None) -> dict:
        """Translate sv->en, summarize in English, then translate the summary en->sv.

        Args:
            conversation: Swedish transcript.
            prompt_helper: Optional summarizer instruction from the request parameters.

        Returns:
            ``{"swedish", "english"}`` on success, otherwise ``{"error": ...}``.
        """
        logger.info("Performing translation and summarization...")
        try:
            logger.info("Translating to English...")
            text_chunks_for_translation = split_by_token_limit_with_speakers(
                conversation, MAX_TRANSLATION_TOKENS, self.tokenizer_sv_en)
            translated_chunks = translate_chunks(
                text_chunks_for_translation,
                self.tokenizer_sv_en,
                self.model_sv_en,
                device=self.device,
            )
            translated_conversation = " ".join(translated_chunks)
            logger.info(f"Translated conversation length: {len(translated_conversation)} chars")

            logger.info("Generating summary with Mistral...")
            token_count = len(self.summarizer_tokenizer(translated_conversation)["input_ids"])
            self._release_memory()
            if token_count > MAX_SUMMARY_TOKENS:
                logger.warning(f"Text too long ({token_count} tokens), using recursive summarization")
                # NOTE(review): handle_long_summarization is called with its original
                # (text, tokenizer, model) arguments, so the custom prompt only
                # applies to the direct path below.
                english_summary = handle_long_summarization(
                    translated_conversation, self.summarizer_tokenizer, self.summarizer_model)
            else:
                english_summary = summarize_with_mistral(
                    translated_conversation,
                    self.summarizer_model,
                    self.summarizer_tokenizer,
                    prompt_helper,
                )
            del translated_conversation
            self._release_memory()

            if not english_summary:
                logger.info("Summarization completed")
                return {"error": "Failed to generate summary"}

            logger.info(f"Generated English summary: {english_summary[:100]}...")
            logger.info("Translating summary back to Swedish...")
            swedish_summary_list = translate_chunks(
                [english_summary],
                self.tokenizer_en_sv,
                self.model_en_sv,
                device=self.device,
            )
            swedish_summary = swedish_summary_list[0] if swedish_summary_list else english_summary
            logger.info("Summarization completed")
            return {"swedish": swedish_summary, "english": english_summary}
        except Exception as e:
            logger.error(f"Translation or summarization error: {str(e)}")
            return {"error": f"Summarization failed: {str(e)}"}

    def __call__(self, inputs):
        """Run the tasks requested in ``inputs["parameters"]`` on ``inputs["inputs"]``.

        Args:
            inputs: Request payload. ``"inputs"`` holds the transcript text and
                ``"parameters"`` holds ``InferenceConfig`` fields (missing or null -> defaults).

        Returns:
            A dict with optional ``"sentiment"`` and ``"summary"`` entries, or
            ``{"error": ...}`` when the parameters are invalid.
        """
        conversation = inputs.get("inputs", "")
        # A null "parameters" value is treated like an absent one.
        parameters = inputs.get("parameters") or {}

        try:
            config = InferenceConfig(**parameters)
        except (ValidationError, TypeError) as e:
            # TypeError covers a non-mapping "parameters" value.
            logger.error(f"Error validating parameters: {e}")
            return {"error": f"Error validating parameters: {e}"}

        logger.info(f"Processing conversation with parameters: {config}")
        result = {}

        if config.sentiment:
            if self.sentiment_pipeline is not None and self.sentiment_tokenizer is not None:
                result["sentiment"] = self._analyze_sentiment(conversation)
            else:
                result["sentiment"] = {"error": "Sentiment analysis model is not available"}

        if config.summarize:
            required_models = (self.summarizer_model, self.model_sv_en, self.model_en_sv)
            if all(component is not None for component in required_models):
                result["summary"] = self._summarize(conversation, config.prompt)
            else:
                result["summary"] = {"error": "Summarization models are not available"}

        self._release_memory()
        return result
|
|