"""
Bangla Text Parser using Transformers + Safetensors
Production-grade text understanding for scene planning
"""

import logging
import os  # noqa: F401  (unused here; kept as part of the module namespace)
import re
from typing import Dict, List

# The ML stack is imported defensively so that the text-only helpers (prompt
# building, output parsing) stay importable without torch/transformers.
# A missing install is reported as ImportError when a parser is constructed.
try:
    import torch
    from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
except ImportError as _import_error:
    torch = None
    AutoModelForSeq2SeqLM = None
    AutoTokenizer = None
    _ML_IMPORT_ERROR = _import_error
else:
    _ML_IMPORT_ERROR = None

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Leading list numbering such as "1.", "10)" or "১." -- for str patterns
# ``\d`` matches any Unicode decimal digit, which includes Bangla numerals.
_NUMBER_PREFIX_RE = re.compile(r"^\d{1,2}\s*[.)]\s*")

# Leading scene label such as "দৃশ্য ১:" or "সিন 2:". The label must start the
# line and the colon must follow closely, so an ordinary sentence that merely
# contains one of these words and a colon is left intact.
_LABEL_PREFIX_RE = re.compile(r"^(?:দৃশ্য|সিন)[^:]{0,20}:\s*")

# A cleaned line must be longer than this many characters to count as a scene.
_MIN_SCENE_LENGTH = 10


def _require_ml_stack() -> None:
    """Raise ImportError if torch/transformers could not be imported."""
    if _ML_IMPORT_ERROR is not None:
        raise ImportError(
            "BanglaSceneParser requires the 'torch' and 'transformers' packages"
        ) from _ML_IMPORT_ERROR


class BanglaSceneParser:
    """
    Transformer-based Bangla text parser for scene extraction.

    Loads a seq2seq model and its tokenizer once at construction time and
    uses them to turn Bangla text into a list of short scene descriptions.
    """

    def __init__(self, model_id: str = "google/mt5-small"):
        """
        Initialize the parser and load the specified model.

        Args:
            model_id: HuggingFace model identifier

        Raises:
            ImportError: if torch or transformers is not installed.
            Exception: whatever model/tokenizer loading raises is re-raised.
        """
        _require_ml_stack()

        self.model_id = model_id
        self.tokenizer = None
        self.model = None
        self.device = "cuda" if torch.cuda.is_available() else "cpu"

        logger.info("Initializing BanglaSceneParser with model: %s", model_id)
        logger.info("Using device: %s", self.device)

        self._load_model()

    def _load_model(self):
        """
        Load tokenizer and model for ``self.model_id``.

        On CUDA the model is loaded in float16 with ``device_map="auto"``;
        on CPU it is loaded in float32 and moved to the CPU explicitly.
        Any loading error is logged with its traceback and re-raised.
        """
        on_gpu = self.device == "cuda"
        try:
            # Load tokenizer with fast implementation
            self.tokenizer = AutoTokenizer.from_pretrained(
                self.model_id,
                use_fast=True,
            )

            # No quantization kwarg is passed: the previous explicit
            # ``load_in_8bit=False`` was the default value anyway.
            self.model = AutoModelForSeq2SeqLM.from_pretrained(
                self.model_id,
                torch_dtype=torch.float16 if on_gpu else torch.float32,
                device_map="auto" if on_gpu else None,
            )

            if not on_gpu:
                self.model = self.model.to(self.device)

            logger.info("Model loaded successfully on %s", self.device)
        except Exception as exc:
            logger.exception("Failed to load model: %s", exc)
            raise

    def extract_scenes(self, text_bn: str, max_scenes: int = 5) -> List[str]:
        """
        Extract scenes from Bangla text using transformer inference.

        Args:
            text_bn: Input Bangla text
            max_scenes: Maximum number of scenes to extract

        Returns:
            List of scene descriptions. ``["Empty text input"]`` when the
            text is None/empty/whitespace, ``[]`` when ``max_scenes`` is not
            positive, and a single ``"Error processing text: ..."`` entry
            when tokenization, generation or parsing raises.
        """
        if not text_bn or not text_bn.strip():
            return ["Empty text input"]
        if max_scenes <= 0:
            # Nothing can be returned, so skip the model call entirely.
            return []

        try:
            prompt = self._create_scene_prompt(text_bn, max_scenes)

            # Tokenize (truncated to 512 tokens) and move to the model's device
            inputs = self.tokenizer(
                prompt,
                return_tensors="pt",
                padding=True,
                truncation=True,
                max_length=512,
            ).to(self.model.device)

            # Deterministic beam search; no gradients needed for inference
            with torch.no_grad():
                output = self.model.generate(
                    **inputs,
                    max_new_tokens=256,
                    num_beams=3,
                    early_stopping=True,
                    do_sample=False,
                    pad_token_id=self.tokenizer.eos_token_id,
                )

            scenes_text = self.tokenizer.decode(output[0], skip_special_tokens=True)
            scenes = self._parse_scenes_output(scenes_text, max_scenes)

            logger.info("Extracted %d scenes from text", len(scenes))
            return scenes
        except Exception as exc:
            # Best-effort API: failures are reported in-band, not raised.
            logger.exception("Scene extraction failed: %s", exc)
            return [f"Error processing text: {str(exc)}"]

    def _create_scene_prompt(self, text_bn: str, max_scenes: int) -> str:
        """Build the Bangla instruction prompt wrapping ``text_bn``."""
        return (
            f"আপনার কাজ: এই বাংলা টেক্সটটিকে সর্বোচ্চ {max_scenes}টি দৃশ্যে ভাগ করুন। "
            "প্রতিটি দৃশ্যের জন্য একটি সংক্ষিপ্ত বর্ণনা দিন যা ভিজ্যুয়াল কন্টেন্ট তৈরির জন্য উপযুক্ত। "
            f"টেক্সট: {text_bn} দৃশ্যগুলো:"
        )

    def _parse_scenes_output(self, output_text: str, max_scenes: int) -> List[str]:
        """
        Parse model output into at most ``max_scenes`` scene descriptions.

        Each non-empty line is stripped of a leading list number (ASCII or
        Bangla digits, ``.`` or ``)``) and then of a leading scene label
        (``দৃশ্য ...:`` / ``সিন ...:``). Lines whose cleaned text is not longer
        than ``_MIN_SCENE_LENGTH`` characters are dropped. If nothing
        survives, ``max_scenes`` generic placeholder scenes are returned.

        NOTE(review): scenes are separated on newlines only; confirm that the
        decoded model output actually contains line breaks.
        """
        scenes: List[str] = []

        for raw_line in output_text.split("\n"):
            if len(scenes) >= max_scenes:
                break

            line = raw_line.strip()
            if not line:
                continue

            scene = _NUMBER_PREFIX_RE.sub("", line, count=1)
            scene = _LABEL_PREFIX_RE.sub("", scene, count=1).strip()

            if len(scene) > _MIN_SCENE_LENGTH:
                scenes.append(scene)

        # Fallback if no scenes were extracted
        if not scenes:
            scenes = [
                f"Scene {i+1}: Visual representation of text segment {i+1}"
                for i in range(max_scenes)
            ]

        return scenes[:max_scenes]

    def get_model_info(self) -> Dict:
        """
        Get information about the loaded model.

        Returns:
            Dict with ``model_id``, ``device``, ``vocab_size`` (``len`` of the
            tokenizer) and ``model_parameters`` (total parameter element
            count); the two counts are 0 when nothing is loaded.
        """
        vocab_size = len(self.tokenizer) if self.tokenizer is not None else 0
        if self.model is not None:
            parameter_count = sum(p.numel() for p in self.model.parameters())
        else:
            parameter_count = 0

        return {
            "model_id": self.model_id,
            "device": self.device,
            "vocab_size": vocab_size,
            "model_parameters": parameter_count,
        }


# Global instance for production use
_parser_instance = None


def get_parser(model_id: str = "google/mt5-small") -> BanglaSceneParser:
    """
    Get or create the global parser instance.

    A new parser is built (and replaces the global one) when none exists yet
    or when the existing one was created for a different ``model_id``.
    """
    global _parser_instance

    if _parser_instance is None or _parser_instance.model_id != model_id:
        _parser_instance = BanglaSceneParser(model_id)

    return _parser_instance


def extract_scenes(text_bn: str, max_scenes: int = 5) -> List[str]:
    """
    Convenience function for scene extraction with the global parser.

    Reuses the global parser if one already exists, whichever model it was
    created with, so a parser configured through ``get_parser(model_id)`` is
    not discarded and reloaded. Otherwise the default parser is created.
    """
    parser = _parser_instance if _parser_instance is not None else get_parser()
    return parser.extract_scenes(text_bn, max_scenes)