# HF Spaces script: convert the fine-tuned mBART model to CTranslate2 format
# and run a small EN->HI translation benchmark.
import os
import time

import ctranslate2
import pandas as pd
import transformers
from datasets import load_dataset
# Directory holding the fine-tuned HF model weights and tokenizer files.
MODEL_DIR = "models"
# Output directory for the converted CTranslate2 model. Kept identical to
# MODEL_DIR for HF Spaces compatibility (model.bin is written here).
CT2_MODEL_DIR = "models"
def optimize_model():
    """Convert the Hugging Face model in MODEL_DIR to CTranslate2 format.

    Writes the converted model (model.bin) into CT2_MODEL_DIR with int8
    quantization. Returns without converting when no source weights are
    present, and skips reconversion when a converted model already exists.
    """
    print("Converting model to CTranslate2 format...")
    # Idempotency guard: a previous run already produced the CT2 model.
    # This keeps HF Spaces restarts fast and avoids copy_files copying the
    # tokenizer files onto themselves when CT2_MODEL_DIR == MODEL_DIR
    # (shutil raises SameFileError on src == dst).
    if os.path.exists(os.path.join(CT2_MODEL_DIR, "model.bin")):
        print(f"Converted model already present in {CT2_MODEL_DIR}; skipping conversion.")
        return
    # Ensure source weights exist before attempting conversion.
    if not any(
        f.startswith("pytorch_model") or f.endswith(".safetensors")
        for f in os.listdir(MODEL_DIR)
    ):
        print(f"Error: No source weights found in {MODEL_DIR}. Cannot convert.")
        return
    # Converter for mBART; copy tokenizer files so the CT2 dir is self-contained.
    converter = ctranslate2.converters.TransformersConverter(
        MODEL_DIR,
        activation_scales=None,
        copy_files=["tokenizer.json", "sentencepiece.bpe.model"],
    )
    # Int8 quantization shrinks the model and usually speeds up CPU inference.
    converter.convert(
        CT2_MODEL_DIR,
        quantization="int8",
        force=True,  # overwrite any stale or partial previous output
    )
    print(f"Model converted and saved to {CT2_MODEL_DIR}")
def benchmark():
    """Benchmark the converted CTranslate2 model on a small EN->HI batch.

    Loads the CT2 translator and the mBART tokenizer, translates a handful
    of short sentences, and prints throughput plus the translations.
    """
    print("\nStarting Benchmarking...")
    translator = ctranslate2.Translator(CT2_MODEL_DIR)
    tokenizer = transformers.MBart50TokenizerFast.from_pretrained(MODEL_DIR)
    # Test data
    texts = ["Namaste", "Hello", "How are you", "Good morning", "India"]
    target_lang = "hi_IN"  # Test with Hindi
    tokenizer.src_lang = "en_XX"
    start_time = time.time()
    # CT2 expects a list of token-string lists. Use encode() followed by
    # convert_ids_to_tokens() so the source-language code and </s> special
    # tokens mBART requires are included — tokenizer.tokenize() alone would
    # omit them and degrade translation quality.
    input_tokens_batch = [
        tokenizer.convert_ids_to_tokens(tokenizer.encode(text)) for text in texts
    ]
    # Translate, forcing the target language as the first decoded token.
    results = translator.translate_batch(
        input_tokens_batch,
        target_prefix=[[target_lang]] * len(texts),
    )
    end_time = time.time()
    # Decode hypotheses back to text, dropping special tokens (lang code, </s>).
    decoded = [
        tokenizer.decode(
            tokenizer.convert_tokens_to_ids(result.hypotheses[0]),
            skip_special_tokens=True,
        )
        for result in results
    ]
    duration = end_time - start_time
    print(f"Inference Time for {len(texts)} sentences: {duration:.4f}s")
    if duration > 0:  # guard the division on extremely fast (sub-tick) runs
        print(f"Speed: {len(texts)/duration:.2f} sentences/s")
    for src, tgt in zip(texts, decoded):
        print(f"{src} -> {tgt}")
def get_dir_size(path):
    """Return the combined size in bytes of every file beneath *path*, recursively."""
    def _entry_size(entry):
        # Regular files contribute their size; subdirectories recurse.
        if entry.is_file():
            return entry.stat().st_size
        if entry.is_dir():
            return get_dir_size(entry.path)
        return 0

    with os.scandir(path) as entries:
        return sum(_entry_size(entry) for entry in entries)
if __name__ == "__main__":
    # A trained model must exist before conversion and benchmarking can run.
    if os.path.exists(MODEL_DIR):
        optimize_model()
        benchmark()
    else:
        print(f"Model directory {MODEL_DIR} not found. Please train first.")