safikhan's picture
Create Readme.md
d23ffa7 verified

Sample code to run the model:

import os
import argparse
import logging
import time
from tqdm import tqdm
import pandas as pd

import torch
from vllm import LLM, SamplingParams
from transformers import AutoTokenizer


logging.basicConfig(level=logging.INFO)

_CODE2LANG = {
    "as": "Assamese",
    "bn": "Bengali",
    "en": "English",
    "gu": "Gujarati",
    "hi": "Hindi",
    "kn": "Kannada",
    "ml": "Malayalam",
    "mr": "Marathi",
    "ne": "Nepali",
    "or": "Odia",
    "pa": "Punjabi",
    "sa": "Sanskrit",
    "ta": "Tamil",
    "te": "Telugu",
    "ur": "Urdu"
}

def main():
    tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-3B-Instruct", padding_side="left")

    src = [
        "When I was young, I used to go to the park every day.",
        "We watched a new movie last week, which was very inspiring.",
        "If you had met me at that time, we would have gone out to eat.",
        "My friend has invited me to his birthday party, and I will give him a gift."
    ]
    
    tgt_lang = "hi"
    model = "ai4bharat/IndicTrans3-beta"
    prompt_dicts = []
    for s in src:
        prompt_dicts.append([{"role": "user", "content": f"Translate the following text to {_CODE2LANG[tgt_lang]}: {s}"}])
    prompts = [tokenizer.apply_chat_template(prompt, tokenize=False, add_generation_prompt=True) for prompt in prompt_dicts]

    print(f"Loading model from: {model}")
    llm = LLM(
        model=model,
        trust_remote_code=True,
        tensor_parallel_size=1,
        download_dir="/projects/data/llmteam/safi/itv3/cache"
    )

    sampling_params = SamplingParams(
        temperature=1.0, #set an appropriate temperature value
        max_tokens=4096,
        repetition_penalty=1.0,
    )

    outputs = llm.generate(prompts, sampling_params)
    results = []
    for input_, output in zip(src, outputs):
        generated_text = output.outputs[0].text
        results.append({
            'input': input_,
            'output': generated_text
        })

    print(results)

if __name__ == "__main__":
    main()