|
|
import os |
|
|
import tarfile |
|
|
from huggingface_hub import hf_hub_download |
|
|
from transformers import AutoTokenizer, AutoModelForCausalLM |
|
|
import gradio as gr |
|
|
|
|
|
|
|
|
# Download the packaged model archive from the Hugging Face Hub.
# hf_hub_download caches the file locally and returns its filesystem path.
archive_path = hf_hub_download(
    repo_id="SamOrion/Llama_3.2_3b_Hindi_Pruned",
    filename="llama-3.2-3b-hindi-pruned.tar.gz",
    repo_type="model",
)
|
|
|
|
|
|
|
|
# Extract the downloaded archive into ./model, stripping the archive's
# single top-level directory so the HF model files land directly in
# extract_dir (where from_pretrained expects config.json etc.).
extract_dir = "./model"
os.makedirs(extract_dir, exist_ok=True)

extract_root = os.path.realpath(extract_dir)
with tarfile.open(archive_path, "r:gz") as tar:
    for member in tar.getmembers():
        # Drop the leading "<top-dir>/" path component, if present.
        parts = member.name.split("/", 1)
        if len(parts) == 2:
            member.name = parts[1]
        # Security: never extract an entry that would resolve outside
        # extract_dir (absolute paths or ".." traversal in a crafted
        # archive). This also skips the stripped top-level directory
        # entry itself, whose name is now empty.
        target = os.path.realpath(os.path.join(extract_dir, member.name))
        if not target.startswith(extract_root + os.sep):
            continue
        tar.extract(member, path=extract_dir)
|
|
|
|
|
|
|
|
# Sanity check: the extracted directory must contain a Hugging Face
# config.json before we hand it to transformers, so a bad/partial
# archive fails loudly here rather than deep inside from_pretrained.
config_path = os.path.join(extract_dir, "config.json")
config_exists = os.path.isfile(config_path)
if not config_exists:
    raise FileNotFoundError(f"config.json not found in {extract_dir}")
|
|
|
|
|
|
|
|
# Load the tokenizer and causal-LM weights from the extracted directory.
tokenizer = AutoTokenizer.from_pretrained(extract_dir)
model = AutoModelForCausalLM.from_pretrained(
    extract_dir,
    torch_dtype="auto",      # keep the checkpoint's native dtype
    device_map="auto",       # place weights on GPU when available, else CPU
    low_cpu_mem_usage=True,  # stream weights to lower peak RAM during load
)
|
|
|
|
|
|
|
|
def chat_fn(prompt, history):
    """Generate one assistant turn for the Gradio chat UI.

    Args:
        prompt: The user's new message text.
        history: Conversation so far as a list of
            ``{"role": ..., "content": ...}`` dicts (Gradio "messages"
            format), or ``None``/empty on the first turn.

    Returns:
        Tuple of (updated history, "") — the empty string clears the
        input textbox.
    """
    history = history or []
    history.append({"role": "user", "content": prompt})

    # Build the model input from the FULL conversation so the model has
    # multi-turn context. (Previously only the latest prompt was
    # encoded, so every turn was answered with no memory of the chat.)
    use_template = bool(getattr(tokenizer, "chat_template", None))
    if use_template:
        # The chat template inserts BOS/role tokens itself, so special
        # tokens must not be added again when tokenizing below.
        text = tokenizer.apply_chat_template(
            history, tokenize=False, add_generation_prompt=True
        )
    else:
        # Fallback for tokenizers without a chat template: old
        # single-turn behavior.
        text = prompt
    inputs = tokenizer(
        text, return_tensors="pt", add_special_tokens=not use_template
    ).to(model.device)
    input_length = inputs["input_ids"].shape[1]

    outputs = model.generate(**inputs, max_new_tokens=100)

    # Decode only the newly generated tokens, not the echoed prompt.
    generated_ids = outputs[0][input_length:]
    response = tokenizer.decode(generated_ids, skip_special_tokens=True)

    history.append({"role": "assistant", "content": response})
    return history, ""
|
|
|
|
|
|
|
|
|
|
|
# --- Gradio UI wiring ---
with gr.Blocks() as demo:
    gr.Markdown("## 🌐 Indus 3.0 Hindi LLM Demo")

    chat_display = gr.Chatbot(type="messages")
    user_input = gr.Textbox(placeholder="Type here…")
    clear_btn = gr.Button("Clear")

    def _reset():
        """Empty both the chat history and the input box."""
        return [], ""

    # Enter in the textbox sends the message; the handler returns the
    # updated history and an empty string to clear the box.
    user_input.submit(chat_fn, [user_input, chat_display], [chat_display, user_input])
    clear_btn.click(_reset, None, [chat_display, user_input])
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Bind to all interfaces so the app is reachable from outside a
    # container; hide the auto-generated API page.
    launch_options = {
        "server_name": "0.0.0.0",
        "server_port": 7860,
        "show_api": False,
    }
    demo.launch(**launch_options)
|
|
|