File size: 6,575 Bytes
a8fc815
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
"""
Bangla Text Parser using Transformers + Safetensors
Production-grade text understanding for scene planning
"""

from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
import torch
import logging
from typing import List, Dict
import os

# Configure logging.
# basicConfig() is invoked when this module is imported and requests INFO
# level on the root logger; `logger` is the module-scoped logger used by
# every function and method below.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

class BanglaSceneParser:
    """
    Transformer-based Bangla text parser for scene extraction.

    The tokenizer and seq2seq model are loaded once, at construction time,
    and reused for every call to ``extract_scenes``.

    Attributes:
        model_id: HuggingFace model identifier passed to ``from_pretrained``.
        tokenizer: Tokenizer loaded for ``model_id`` (``None`` until loaded).
        model: Seq2seq model loaded for ``model_id`` (``None`` until loaded).
        device: ``"cuda"`` when CUDA is available, otherwise ``"cpu"``.
    """

    # A parsed line must be longer than this many characters to be kept as
    # a scene; shorter fragments are treated as noise.
    _MIN_SCENE_LENGTH = 10

    def __init__(self, model_id: str = "google/mt5-small"):
        """
        Initialize the parser and eagerly load the specified model.

        Args:
            model_id: HuggingFace model identifier

        Raises:
            Exception: Whatever the tokenizer/model loading raised; the
                error is logged and re-raised unchanged.
        """
        self.model_id = model_id
        self.tokenizer = None
        self.model = None
        self.device = "cuda" if torch.cuda.is_available() else "cpu"

        logger.info(f"Initializing BanglaSceneParser with model: {model_id}")
        logger.info(f"Using device: {self.device}")

        self._load_model()

    def _load_model(self):
        """
        Load the tokenizer and model for ``self.model_id``.

        On CUDA the weights are loaded in float16 with ``device_map="auto"``;
        on CPU they are loaded in float32 and moved to the CPU explicitly.

        Raises:
            Exception: Any loading failure is logged and re-raised.
        """
        on_gpu = self.device == "cuda"
        try:
            # Load tokenizer with fast implementation
            self.tokenizer = AutoTokenizer.from_pretrained(
                self.model_id,
                use_fast=True
            )

            # Load model with memory optimization.
            # The former `load_in_8bit=False` argument was dropped: False is
            # the default, and recent transformers releases expect 8-bit
            # loading to be requested through `quantization_config` instead.
            # NOTE(review): T5-family checkpoints are reported to overflow in
            # float16 - confirm GPU output quality (bfloat16 may be safer).
            self.model = AutoModelForSeq2SeqLM.from_pretrained(
                self.model_id,
                torch_dtype=torch.float16 if on_gpu else torch.float32,
                device_map="auto" if on_gpu else None,
            )

            # With device_map=None nothing has placed the model yet.
            if not on_gpu:
                self.model = self.model.to(self.device)

            logger.info(f"Model loaded successfully on {self.device}")

        except Exception as e:
            logger.error(f"Failed to load model: {e}")
            raise

    def extract_scenes(self, text_bn: str, max_scenes: int = 5) -> List[str]:
        """
        Extract scenes from Bangla text using transformer inference.

        Args:
            text_bn: Input Bangla text
            max_scenes: Maximum number of scenes to extract

        Returns:
            List of scene descriptions. A missing or blank input yields
            ``["Empty text input"]``; a failure during inference yields a
            single-element list holding the error message (no exception
            propagates to the caller).
        """
        # Guard against None as well as blank strings so the "empty" result
        # is returned instead of an AttributeError escaping from .strip().
        if not text_bn or not text_bn.strip():
            return ["Empty text input"]

        try:
            # Create optimized prompt
            prompt = self._create_scene_prompt(text_bn, max_scenes)

            # Tokenize with proper padding
            inputs = self.tokenizer(
                prompt,
                return_tensors="pt",
                padding=True,
                truncation=True,
                max_length=512
            ).to(self.model.device)

            # Prefer the tokenizer's real pad token; fall back to EOS only
            # for tokenizers that do not define one.
            pad_token_id = self.tokenizer.pad_token_id
            if pad_token_id is None:
                pad_token_id = self.tokenizer.eos_token_id

            # Generate with controlled parameters
            with torch.no_grad():
                output = self.model.generate(
                    **inputs,
                    max_new_tokens=256,
                    num_beams=3,
                    early_stopping=True,
                    do_sample=False,  # Deterministic output
                    pad_token_id=pad_token_id
                )

            # Decode and clean output
            scenes_text = self.tokenizer.decode(output[0], skip_special_tokens=True)
            scenes = self._parse_scenes_output(scenes_text, max_scenes)

            logger.info(f"Extracted {len(scenes)} scenes from text")
            return scenes

        except Exception as e:
            # logger.exception keeps the traceback, which the returned
            # message alone would lose.
            logger.exception(f"Scene extraction failed: {e}")
            return [f"Error processing text: {str(e)}"]

    def _create_scene_prompt(self, text_bn: str, max_scenes: int) -> str:
        """Build the Bangla instruction prompt embedding the text and scene limit."""
        return f"""আপনার কাজ: এই বাংলা টেক্সটটিকে সর্বোচ্চ {max_scenes}টি দৃশ্যে ভাগ করুন। প্রতিটি দৃশ্যের জন্য একটি সংক্ষিপ্ত বর্ণনা দিন যা ভিজ্যুয়াল কন্টেন্ট তৈরির জন্য উপযুক্ত।

টেক্সট: {text_bn}

দৃশ্যগুলো:"""

    @staticmethod
    def _strip_list_number(line: str):
        """
        Remove a leading ``<digits>.`` list marker from ``line``.

        Any run of decimal digits is accepted, so multi-digit numbers
        ("10.") and Bangla numerals ("১.") are handled as well as "1."-"9.".

        Returns:
            The stripped remainder after the marker, or ``None`` when the
            line does not start with such a marker.
        """
        digit_count = 0
        while digit_count < len(line) and line[digit_count].isdecimal():
            digit_count += 1
        if digit_count and line.startswith('.', digit_count):
            return line[digit_count + 1:].strip()
        return None

    def _parse_scenes_output(self, output_text: str, max_scenes: int) -> List[str]:
        """
        Parse model output into scene descriptions.

        Each non-blank line is one candidate scene. A leading list number
        ("3.", "12.", "১.") is removed; otherwise, for lines that start with
        'দৃশ্য' or contain 'সিন', the text after the first ':' is used when a
        ':' is present. Candidates that are not longer than
        ``_MIN_SCENE_LENGTH`` characters are dropped.

        Args:
            output_text: Decoded text produced by the model.
            max_scenes: Maximum number of scenes to return.

        Returns:
            At most ``max_scenes`` scene strings. If no line qualifies,
            ``max_scenes`` generic placeholder scenes are returned instead.
        """
        scenes = []

        for raw_line in output_text.split('\n'):
            if len(scenes) >= max_scenes:
                break  # enough scenes collected; remaining lines are ignored

            line = raw_line.strip()
            if not line:
                continue

            numbered = self._strip_list_number(line)
            if numbered is not None:
                scene = numbered
            elif line.startswith('দৃশ্য') or 'সিন' in line:
                scene = line.split(':', 1)[1].strip() if ':' in line else line
            else:
                scene = line

            if len(scene) > self._MIN_SCENE_LENGTH:  # Minimum meaningful length
                scenes.append(scene)

        # Fallback if no scenes were extracted
        if not scenes:
            scenes = [f"Scene {i+1}: Visual representation of text segment {i+1}"
                      for i in range(max_scenes)]

        return scenes[:max_scenes]

    def get_model_info(self) -> Dict:
        """
        Get information about the loaded model.

        Returns:
            Dict with ``model_id``, ``device``, ``vocab_size`` (tokenizer
            length, 0 if no tokenizer) and ``model_parameters`` (total
            parameter element count, 0 if no model).
        """
        vocab_size = len(self.tokenizer) if self.tokenizer is not None else 0
        if self.model is not None:
            parameter_count = sum(p.numel() for p in self.model.parameters())
        else:
            parameter_count = 0

        return {
            "model_id": self.model_id,
            "device": self.device,
            "vocab_size": vocab_size,
            "model_parameters": parameter_count
        }

# Module-level cache holding the most recently constructed parser
_parser_instance = None

def get_parser(model_id: str = "google/mt5-small") -> BanglaSceneParser:
    """
    Return the cached parser for ``model_id``, building it when needed.

    A new ``BanglaSceneParser`` is constructed (and replaces the cached one)
    when nothing is cached yet or when the cached parser was created for a
    different ``model_id``.
    """
    global _parser_instance
    cached = _parser_instance
    if cached is not None and cached.model_id == model_id:
        return cached
    _parser_instance = BanglaSceneParser(model_id)
    return _parser_instance

def extract_scenes(text_bn: str, max_scenes: int = 5) -> List[str]:
    """
    Extract scenes from ``text_bn`` using the shared default parser.

    Delegates to ``get_parser()`` (called with its default model) and
    forwards both arguments to that parser's ``extract_scenes`` method.
    """
    return get_parser().extract_scenes(text_bn, max_scenes)