"""Gradio chat front-end that streams replies from the Indian History SLM on CPU."""

import os

# Set before torch is imported (same ordering as the original script) so the
# OpenMP runtime can pick the limit up when it is loaded.
os.environ["OMP_NUM_THREADS"] = "2"

import gc
from threading import Event, Thread

import gradio as gr
import torch
from huggingface_hub import hf_hub_download
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    StoppingCriteria,
    StoppingCriteriaList,
    TextIteratorStreamer,
)

torch.set_num_threads(2)

model_path = "ruhzi/Indian_History_SLM"

# Number of trailing history messages fed back to the model (6 = 3 turns).
MAX_HISTORY_MESSAGES = 6
# Upper bound, in seconds, on how long cleanup waits for the worker thread.
_JOIN_TIMEOUT_S = 5.0

tokenizer = AutoTokenizer.from_pretrained(model_path)

# The chat template is stored as a separate file in the model repo; fetch it
# and attach it to the tokenizer explicitly.
template_file = hf_hub_download(repo_id=model_path, filename="chat_template.jinja")
with open(template_file, "r", encoding="utf-8") as f:
    tokenizer.chat_template = f.read()

model = AutoModelForCausalLM.from_pretrained(
    model_path,
    dtype=torch.float32,
    low_cpu_mem_usage=True,
)

# Stop flag of the most recently started request. Each call to
# ``chat_inference`` sets the previous flag and installs a fresh one.
stop_event = Event()


class _EventStoppingCriteria(StoppingCriteria):
    """Stopping criterion that ends generation once a threading.Event is set."""

    def __init__(self, event):
        self._event = event

    def __call__(self, input_ids, scores, **kwargs):
        # One flag per batch row, all carrying the current state of the event.
        return torch.full(
            (input_ids.shape[0],),
            self._event.is_set(),
            dtype=torch.bool,
            device=input_ids.device,
        )


def _content_to_text(content):
    """Return the plain text of a history message's ``content`` field.

    Strings are returned unchanged. A dict contributes its ``"text"`` value
    (if it is a string), and a list/tuple is flattened part by part, joined
    with newlines. ``None`` becomes an empty string.

    NOTE(review): some Gradio versions appear to deliver ``content`` as a list
    of typed parts rather than a string -- confirm against the installed
    version. String content is unaffected by this helper.
    """
    if isinstance(content, str):
        return content
    if content is None:
        return ""
    if isinstance(content, dict):
        text = content.get("text")
        return text if isinstance(text, str) else ""
    if isinstance(content, (list, tuple)):
        parts = (_content_to_text(part) for part in content)
        return "\n".join(part for part in parts if part)
    return str(content)


def chat_inference(message, history):
    """Stream the model's reply to ``message`` as a growing string.

    Args:
        message: The new user message.
        history: Prior conversation as a list of
            ``{"role": ..., "content": ...}`` dicts; only the last
            ``MAX_HISTORY_MESSAGES`` entries are used.

    Yields:
        The reply accumulated so far, once per chunk emitted by the streamer.

    Generation runs in a daemon thread. It is asked to stop as soon as this
    generator finishes for any reason (normal end, the consumer closing the
    generator, or an exception), so no generation keeps running in the
    background after the caller has gone away.
    """
    global stop_event

    # Cancel whichever request was started before this one, then register a
    # fresh flag for this request.
    stop_event.set()
    stop_event = Event()
    current_stop = stop_event

    # Gradio 6.x: history is a list of {"role": "...", "content": "..."} dicts
    recent_history = (history or [])[-MAX_HISTORY_MESSAGES:]
    messages = [
        {"role": entry["role"], "content": _content_to_text(entry["content"])}
        for entry in recent_history
    ]
    messages.append({"role": "user", "content": message})

    input_text = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True,
        enable_thinking=False,
    )
    inputs = tokenizer([input_text], return_tensors="pt").to("cpu")

    streamer = TextIteratorStreamer(
        tokenizer,
        timeout=60.0,
        skip_prompt=True,
        skip_special_tokens=True,
    )
    generate_kwargs = dict(
        **inputs,
        streamer=streamer,
        max_new_tokens=512,
        do_sample=True,
        temperature=0.7,
        top_p=0.8,
        # Lets the stop flag actually interrupt generate() between tokens.
        stopping_criteria=StoppingCriteriaList(
            [_EventStoppingCriteria(current_stop)]
        ),
    )

    t = Thread(target=model.generate, kwargs=generate_kwargs, daemon=True)
    t.start()

    partial_message = ""
    try:
        for new_token in streamer:
            if current_stop.is_set():
                break
            partial_message += new_token
            yield partial_message
    finally:
        # Reached on normal completion, on generator close (e.g. Stop button)
        # and on errors such as a streamer timeout: tell the worker to stop
        # and give it a bounded amount of time to wind down.
        current_stop.set()
        t.join(timeout=_JOIN_TIMEOUT_S)
        # generate_kwargs references the same tensors as inputs, so both must
        # be dropped before collecting.
        del inputs, generate_kwargs
        gc.collect()


demo = gr.ChatInterface(
    fn=chat_inference,
    title="Indian History SLM",
    description="Ask me anything about Indian History!",
    stop_btn="Stop",
    concurrency_limit=1,
)

if __name__ == "__main__":
    demo.launch()