|
|
""" |
|
|
Bangla Text Parser using Transformers + Safetensors |
|
|
Production-grade text understanding for scene planning |
|
|
""" |
|
|
|
|
|
# Standard library
import logging
import os
import re
from typing import Dict, List

# Third-party
import torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
|
|
|
|
|
|
|
|
# NOTE(review): basicConfig here configures the *root* logger as an import-time
# side effect. That is convenient when this module runs as a script, but if it
# is imported as a library it may override the host application's logging
# setup — consider moving this call to the application entry point.
logging.basicConfig(level=logging.INFO)

# Module-level logger, named after this module per the standard convention.
logger = logging.getLogger(__name__)
|
|
|
|
|
class BanglaSceneParser:
    """
    Transformer-based Bangla text parser for scene extraction.

    Loads a seq2seq model once at construction time (safetensors weights via
    ``from_pretrained``, fp16 on CUDA / fp32 on CPU) and reuses it for every
    call to :meth:`extract_scenes`.
    """

    def __init__(self, model_id: str = "google/mt5-small"):
        """
        Initialize the parser and eagerly load the model.

        Args:
            model_id: HuggingFace model identifier.

        Raises:
            Exception: re-raised from ``_load_model`` if the tokenizer or
                model cannot be downloaded/loaded.
        """
        self.model_id = model_id
        self.tokenizer = None   # set by _load_model
        self.model = None       # set by _load_model
        self.device = "cuda" if torch.cuda.is_available() else "cpu"

        logger.info(f"Initializing BanglaSceneParser with model: {model_id}")
        logger.info(f"Using device: {self.device}")

        self._load_model()

    def _load_model(self):
        """Load model and tokenizer with proper configuration.

        Uses half precision plus accelerate's ``device_map="auto"`` on CUDA,
        and full precision on CPU. Errors are logged and re-raised so callers
        never end up with a half-initialized parser.
        """
        try:
            self.tokenizer = AutoTokenizer.from_pretrained(
                self.model_id,
                use_fast=True
            )

            # NOTE: the previous explicit `load_in_8bit=False` kwarg was
            # dropped — it has been removed from `from_pretrained` in recent
            # transformers releases (quantization is configured through
            # `quantization_config` instead), and passing it raises TypeError.
            self.model = AutoModelForSeq2SeqLM.from_pretrained(
                self.model_id,
                torch_dtype=torch.float16 if self.device == "cuda" else torch.float32,
                device_map="auto" if self.device == "cuda" else None,
            )

            # device_map is None on CPU, so place the weights explicitly.
            if self.device == "cpu":
                self.model = self.model.to(self.device)

            logger.info(f"Model loaded successfully on {self.device}")

        except Exception as e:
            logger.error(f"Failed to load model: {e}")
            raise

    def extract_scenes(self, text_bn: str, max_scenes: int = 5) -> List[str]:
        """
        Extract scenes from Bangla text using transformer inference.

        Best-effort: inference failures are logged and reported as a
        single-element error list rather than raised, so callers in a
        pipeline never crash here.

        Args:
            text_bn: Input Bangla text.
            max_scenes: Maximum number of scenes to extract.

        Returns:
            List of scene descriptions (at most ``max_scenes`` entries).
        """
        if not text_bn.strip():
            return ["Empty text input"]

        try:
            prompt = self._create_scene_prompt(text_bn, max_scenes)

            # Tokenize on CPU, then move tensors to wherever the model lives
            # (accelerate may have sharded it across devices).
            inputs = self.tokenizer(
                prompt,
                return_tensors="pt",
                padding=True,
                truncation=True,
                max_length=512
            ).to(self.model.device)

            # Deterministic beam search; no grad tracking needed for inference.
            with torch.no_grad():
                output = self.model.generate(
                    **inputs,
                    max_new_tokens=256,
                    num_beams=3,
                    early_stopping=True,
                    do_sample=False,
                    pad_token_id=self.tokenizer.eos_token_id
                )

            scenes_text = self.tokenizer.decode(output[0], skip_special_tokens=True)
            scenes = self._parse_scenes_output(scenes_text, max_scenes)

            logger.info(f"Extracted {len(scenes)} scenes from text")
            return scenes

        except Exception as e:
            logger.error(f"Scene extraction failed: {e}")
            return [f"Error processing text: {str(e)}"]

    def _create_scene_prompt(self, text_bn: str, max_scenes: int) -> str:
        """Create optimized Bangla prompt instructing the model to split the
        text into at most ``max_scenes`` visually-describable scenes."""
        return f"""আপনার কাজ: এই বাংলা টেক্সটটিকে সর্বোচ্চ {max_scenes}টি দৃশ্যে ভাগ করুন। প্রতিটি দৃশ্যের জন্য একটি সংক্ষিপ্ত বর্ণনা দিন যা ভিজ্যুয়াল কন্টেন্ট তৈরির জন্য উপযুক্ত।

টেক্সট: {text_bn}

দৃশ্যগুলো:"""

    def _parse_scenes_output(self, output_text: str, max_scenes: int) -> List[str]:
        """Parse model output into scene descriptions.

        Strips common list prefixes ("1.", "দৃশ্য ...:", "সিন ...:") from each
        line, drops fragments too short to be usable (<= 10 chars), and falls
        back to generic placeholder scenes when nothing usable was produced.
        """
        scenes: List[str] = []

        for line in output_text.split('\n'):
            line = line.strip()
            if not line:
                continue

            # "3. description" -> "description". The regex also handles
            # multi-digit prefixes ("10.", "11.", ...) that the previous
            # hard-coded startswith tuple ('1.' .. '9.') missed.
            numbered = re.match(r'\d+\.\s*(.*)', line)
            if numbered:
                scene = numbered.group(1).strip()
            elif line.startswith('দৃশ্য') or 'সিন' in line:
                scene = line.split(':', 1)[1].strip() if ':' in line else line
            else:
                scene = line

            # Skip bare prefixes / fragments too short to be a description.
            if scene and len(scene) > 10:
                scenes.append(scene)
                if len(scenes) == max_scenes:
                    break

        if not scenes:
            # Model produced nothing usable — emit placeholders so downstream
            # scene planning still receives `max_scenes` entries.
            scenes = [f"Scene {i+1}: Visual representation of text segment {i+1}"
                      for i in range(max_scenes)]

        return scenes[:max_scenes]

    def get_model_info(self) -> Dict:
        """Return a summary of the loaded model (id, device, vocab size,
        parameter count); zeros/placeholders when loading never completed."""
        return {
            "model_id": self.model_id,
            "device": self.device,
            "vocab_size": len(self.tokenizer) if self.tokenizer else 0,
            "model_parameters": sum(p.numel() for p in self.model.parameters()) if self.model else 0
        }
|
|
|
|
|
|
|
|
# Module-level cache for the singleton parser (None until first use).
_parser_instance = None


def get_parser(model_id: str = "google/mt5-small") -> BanglaSceneParser:
    """Return the cached parser, rebuilding it when a different model is requested."""
    global _parser_instance
    cached = _parser_instance
    if cached is None or cached.model_id != model_id:
        cached = BanglaSceneParser(model_id)
        _parser_instance = cached
    return cached
|
|
|
|
|
def extract_scenes(text_bn: str, max_scenes: int = 5) -> List[str]:
    """Convenience function: extract scenes via the shared global parser."""
    return get_parser().extract_scenes(text_bn, max_scenes)