# SLM / app.py — Gradio chat app for the ruhzi/Indian_History_SLM model.
# (Last revision: commit 6a1eaed, "Update app.py", by ruhzi.)
import gc
import os
from threading import Event, Thread

# Set before torch is imported so its OpenMP thread pool is sized to match.
os.environ["OMP_NUM_THREADS"] = "2"

import gradio as gr
import torch
from huggingface_hub import hf_hub_download
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    StoppingCriteria,
    StoppingCriteriaList,
    TextIteratorStreamer,
)
# Keep torch's intra-op thread pool in line with the OMP_NUM_THREADS value
# assigned at import time.
torch.set_num_threads(int(os.environ["OMP_NUM_THREADS"]))

model_path = "ruhzi/Indian_History_SLM"

tokenizer = AutoTokenizer.from_pretrained(model_path)

# The model repo ships its chat template as a standalone Jinja file; fetch it
# and install it on the tokenizer, replacing whatever template was loaded.
template_file = hf_hub_download(
    repo_id=model_path,
    filename="chat_template.jinja",
)
with open(template_file, encoding="utf-8") as f:
    tokenizer.chat_template = f.read()

# Full-precision weights for CPU inference; low_cpu_mem_usage asks
# transformers to keep peak RAM low while the checkpoint is loaded.
model = AutoModelForCausalLM.from_pretrained(
    model_path,
    low_cpu_mem_usage=True,
    dtype=torch.float32,
)
# Event owned by the most recent chat_inference call. A new call sets the
# previous event so any generation still tied to it halts.
stop_event = Event()


class _StopOnEvent(StoppingCriteria):
    """Stopping criterion that ends ``model.generate`` once ``event`` is set."""

    def __init__(self, event):
        self.event = event

    def __call__(self, input_ids, scores, **kwargs):
        # One flag per sequence in the batch; every sequence stops together.
        return torch.full(
            (input_ids.shape[0],),
            self.event.is_set(),
            dtype=torch.bool,
            device=input_ids.device,
        )


def _build_messages(message, history):
    """Build the chat-template message list for one request.

    Keeps only the last 6 entries of ``history`` (3 turns if user and
    assistant messages alternate) to bound the prompt length, copies just the
    ``role`` and ``content`` keys of each entry, and appends ``message`` as
    the new user turn.
    """
    # Gradio 6.x: history is a list of {"role": "...", "content": "..."} dicts
    messages = [
        {"role": entry["role"], "content": entry["content"]}
        for entry in history[-6:]
    ]
    messages.append({"role": "user", "content": message})
    return messages


def chat_inference(message, history):
    """Stream the model's reply to ``message`` for ``gr.ChatInterface``.

    Generation runs in a background thread and text is read back through a
    ``TextIteratorStreamer``. When this generator finishes, fails, or is
    closed early, the worker is told to stop and is joined, so an abandoned
    generation does not keep using the CPU threads.

    Args:
        message: The new user message.
        history: Prior chat entries as role/content dicts.

    Yields:
        The reply text accumulated so far, after each decoded chunk.

    Raises:
        Exception: Any error raised by ``model.generate`` in the worker
            thread is re-raised here.
        queue.Empty: If the streamer receives nothing for 60 seconds.
    """
    global stop_event
    # Halt any generation tied to the previous call, then give this call its
    # own event.
    stop_event.set()
    stop_event = Event()
    current_stop = stop_event

    input_text = tokenizer.apply_chat_template(
        _build_messages(message, history),
        tokenize=False,
        add_generation_prompt=True,
        enable_thinking=False,
    )
    # NOTE(review): tokenizer() adds special tokens by default; if the chat
    # template already emits a BOS token this would duplicate it — confirm
    # for this model's tokenizer/template.
    inputs = tokenizer([input_text], return_tensors="pt").to("cpu")
    streamer = TextIteratorStreamer(
        tokenizer,
        timeout=60.0,
        skip_prompt=True,
        skip_special_tokens=True,
    )
    generate_kwargs = dict(
        **inputs,
        streamer=streamer,
        max_new_tokens=512,
        do_sample=True,
        temperature=0.7,
        top_p=0.8,
        # Lets setting current_stop actually end generation, rather than
        # only ignoring its output.
        stopping_criteria=StoppingCriteriaList([_StopOnEvent(current_stop)]),
    )
    worker_errors = []

    def _generate():
        # Runs in the worker thread. Without this handler, an exception would
        # leave the streamer without an end-of-stream marker and the consumer
        # would block until the streamer timeout.
        try:
            model.generate(**generate_kwargs)
        except Exception as exc:  # handed to the consumer and re-raised there
            worker_errors.append(exc)
            streamer.end()

    worker = Thread(target=_generate, daemon=True)
    worker.start()

    partial_message = ""
    try:
        for new_text in streamer:
            if current_stop.is_set():
                break
            partial_message += new_text
            yield partial_message
        if worker_errors:
            raise worker_errors[0]
    finally:
        # Also reached when the generator is closed early (GeneratorExit at
        # the yield). Setting the event makes the stopping criterion end
        # generation at its next step; joining waits for the worker to exit.
        current_stop.set()
        worker.join(timeout=60.0)
        del inputs
        gc.collect()
# Streaming chat UI around chat_inference. concurrency_limit=1 allows only
# one chat_inference run at a time; stop_btn labels the stop control "Stop".
demo = gr.ChatInterface(
    chat_inference,
    concurrency_limit=1,
    stop_btn="Stop",
    title="Indian History SLM",
    description="Ask me anything about Indian History!",
)


if __name__ == "__main__":
    # Start the web server only when executed as a script, not on import.
    demo.launch()