# memo/models/text/bangla_parser.py
# Uploaded by likhonsheikh: "Production-grade Transformers + Safetensors implementation"
# (revision a8fc815, verified)
"""
Bangla Text Parser using Transformers + Safetensors
Production-grade text understanding for scene planning
"""
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
import torch
import logging
from typing import List, Dict
import os
# Configure logging
# NOTE(review): basicConfig at import time mutates the process-wide root
# logger — confirm this is intended if the module is used as a library.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
class BanglaSceneParser:
    """
    Transformer-based Bangla text parser for scene extraction.

    Wraps a HuggingFace seq2seq model (default: google/mt5-small) and turns
    free-form Bangla text into a short list of scene descriptions intended
    for downstream visual-content planning.
    """

    def __init__(self, model_id: str = "google/mt5-small"):
        """
        Initialize the parser and eagerly load tokenizer + model.

        Args:
            model_id: HuggingFace model identifier of a seq2seq LM.

        Raises:
            Exception: re-raised from the underlying load failure so the
                caller sees the real cause (network, missing weights, ...).
        """
        self.model_id = model_id
        self.tokenizer = None
        self.model = None
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        logger.info(f"Initializing BanglaSceneParser with model: {model_id}")
        logger.info(f"Using device: {self.device}")
        self._load_model()

    def _load_model(self):
        """Load model and tokenizer with device/dtype-appropriate settings."""
        try:
            # Prefer the fast (Rust) tokenizer implementation.
            self.tokenizer = AutoTokenizer.from_pretrained(
                self.model_id,
                use_fast=True
            )
            # fp16 on GPU to halve memory; fp32 on CPU where fp16 is slow.
            # NOTE: the deprecated `load_in_8bit` kwarg was removed — pass a
            # BitsAndBytesConfig via `quantization_config` if 8-bit loading
            # is ever needed.
            self.model = AutoModelForSeq2SeqLM.from_pretrained(
                self.model_id,
                torch_dtype=torch.float16 if self.device == "cuda" else torch.float32,
                device_map="auto" if self.device == "cuda" else None,
            )
            if self.device == "cpu":
                # device_map is None on CPU, so place the weights explicitly.
                self.model = self.model.to(self.device)
            logger.info(f"Model loaded successfully on {self.device}")
        except Exception as e:
            logger.error(f"Failed to load model: {e}")
            raise

    def extract_scenes(self, text_bn: str, max_scenes: int = 5) -> List[str]:
        """
        Extract scenes from Bangla text using transformer inference.

        Args:
            text_bn: Input Bangla text.
            max_scenes: Maximum number of scenes to extract.

        Returns:
            List of scene description strings; on failure a single-element
            list carrying the error message (never raises to the caller).
        """
        if not text_bn.strip():
            return ["Empty text input"]
        try:
            prompt = self._create_scene_prompt(text_bn, max_scenes)
            # Tokenize with truncation so long inputs cannot overflow the
            # model's context window.
            inputs = self.tokenizer(
                prompt,
                return_tensors="pt",
                padding=True,
                truncation=True,
                max_length=512
            ).to(self.model.device)
            # Deterministic beam search — reproducible scene lists.
            with torch.no_grad():
                output = self.model.generate(
                    **inputs,
                    max_new_tokens=256,
                    num_beams=3,
                    early_stopping=True,
                    do_sample=False,
                    # NOTE(review): uses EOS as the pad id; T5-family
                    # tokenizers also define pad_token_id — confirm intended.
                    pad_token_id=self.tokenizer.eos_token_id
                )
            scenes_text = self.tokenizer.decode(output[0], skip_special_tokens=True)
            scenes = self._parse_scenes_output(scenes_text, max_scenes)
            logger.info(f"Extracted {len(scenes)} scenes from text")
            return scenes
        except Exception as e:
            logger.error(f"Scene extraction failed: {e}")
            return [f"Error processing text: {str(e)}"]

    def _create_scene_prompt(self, text_bn: str, max_scenes: int) -> str:
        """Create the Bangla instruction prompt for scene extraction."""
        return f"""আপনার কাজ: এই বাংলা টেক্সটটিকে সর্বোচ্চ {max_scenes}টি দৃশ্যে ভাগ করুন। প্রতিটি দৃশ্যের জন্য একটি সংক্ষিপ্ত বর্ণনা দিন যা ভিজ্যুয়াল কন্টেন্ট তৈরির জন্য উপযুক্ত।
টেক্সট: {text_bn}
দৃশ্যগুলো:"""

    def _parse_scenes_output(self, output_text: str, max_scenes: int) -> List[str]:
        """
        Parse raw model output into at most `max_scenes` scene descriptions.

        Recognizes "N. text" numbering with any number of digits (the old
        check only matched single-digit '1.'-'9.' prefixes, so '10.' and
        Bangla-numeral prefixes were kept verbatim), "দৃশ্য:"/"সিন" labels,
        and plain lines; falls back to placeholders if nothing usable found.
        """
        scenes: List[str] = []
        for line in output_text.split('\n'):
            line = line.strip()
            if not line or len(scenes) >= max_scenes:
                continue
            head, sep, tail = line.partition('.')
            if sep and head.isdigit():
                # Numbered line ("1. ...", "12. ...", Bangla digits too —
                # str.isdigit accepts them): strip the numbering prefix.
                scene = tail.strip()
            elif line.startswith('দৃশ্য') or 'সিন' in line:
                # Labeled scene line; keep the text after the colon if any.
                scene = line.split(':', 1)[1].strip() if ':' in line else line
            else:
                scene = line
            if scene and len(scene) > 10:  # drop fragments too short to describe a scene
                scenes.append(scene)
        if not scenes:
            # Fallback so callers always receive max_scenes entries.
            scenes = [f"Scene {i+1}: Visual representation of text segment {i+1}"
                      for i in range(max_scenes)]
        return scenes[:max_scenes]

    def get_model_info(self) -> Dict:
        """Return a summary dict about the loaded model (zeros if unloaded)."""
        return {
            "model_id": self.model_id,
            "device": self.device,
            "vocab_size": len(self.tokenizer) if self.tokenizer else 0,
            "model_parameters": sum(p.numel() for p in self.model.parameters()) if self.model else 0
        }
# Global instance for production use
# Lazily created by get_parser(); replaced when a different model_id is requested.
_parser_instance = None
def get_parser(model_id: str = "google/mt5-small") -> BanglaSceneParser:
    """Return the shared parser, (re)building it when the model id changes."""
    global _parser_instance
    stale = _parser_instance is None or _parser_instance.model_id != model_id
    if stale:
        _parser_instance = BanglaSceneParser(model_id)
    return _parser_instance
def extract_scenes(text_bn: str, max_scenes: int = 5) -> List[str]:
    """Module-level shortcut that delegates to the shared parser instance."""
    return get_parser().extract_scenes(text_bn, max_scenes)