File size: 2,369 Bytes
9aed480 fd449f8 c0a8087 9aed480 4185a6c c0741fa c0a8087 9aed480 c0741fa c0a8087 4185a6c c0a8087 6a1eaed 9aed480 c0a8087 fd449f8 c0741fa 82cd61c c0741fa 5acdaf9 c0741fa 5acdaf9 254de85 5acdaf9 254de85 c0741fa 254de85 fd449f8 c0a8087 4185a6c c0741fa c0a8087 c0741fa 9aed480 c0a8087 c0741fa 254de85 c0741fa c0a8087 254de85 4185a6c c0a8087 c0741fa 254de85 c0741fa c0a8087 82cd61c c0a8087 82cd61c a94700b 5acdaf9 c0741fa fd449f8 c0a8087 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 | import os
os.environ["OMP_NUM_THREADS"] = "2"
import gc
from threading import Event, Thread

import gradio as gr
import torch
from huggingface_hub import hf_hub_download
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    StoppingCriteriaList,
    TextIteratorStreamer,
)
# Cap PyTorch's intra-op thread pool at the same size as OMP_NUM_THREADS above.
torch.set_num_threads(2)

model_path = "ruhzi/Indian_History_SLM"

# The repo ships its chat template as a standalone Jinja file; fetch it and
# install it on the tokenizer explicitly.
tokenizer = AutoTokenizer.from_pretrained(model_path)
template_file = hf_hub_download(filename="chat_template.jinja", repo_id=model_path)
with open(template_file, mode="r", encoding="utf-8") as f:
    tokenizer.chat_template = f.read()

# Full-precision weights (no device_map, so they load on the default CPU device);
# low_cpu_mem_usage is meant to reduce peak RAM while loading.
model = AutoModelForCausalLM.from_pretrained(
    model_path,
    low_cpu_mem_usage=True,
    dtype=torch.float32,
)
# Keep at most this many prior chat messages as context (6 = 3 user/assistant turns).
MAX_HISTORY_MESSAGES = 6

# Stop flag of the most recent request. Each new call sets the previous flag so
# that a generation still running in the background halts instead of competing
# for the CPU threads.
stop_event = Event()


def _build_messages(message, history):
    """Return chat-template messages: the last few history entries plus the new user turn.

    Gradio 6.x passes ``history`` as a list of ``{"role": ..., "content": ...}``
    dicts; only those two keys are forwarded to the chat template.
    """
    messages = [
        {"role": entry["role"], "content": entry["content"]}
        for entry in (history or [])[-MAX_HISTORY_MESSAGES:]
    ]
    messages.append({"role": "user", "content": message})
    return messages


def _stop_when_set(event):
    """Return a ``generate`` stopping criterion that fires once ``event`` is set."""

    def criterion(input_ids, scores, **kwargs):
        # One flag per sequence in the batch, all mirroring the event state.
        return torch.full(
            (input_ids.shape[0],),
            event.is_set(),
            dtype=torch.bool,
            device=input_ids.device,
        )

    return criterion


def _generate_into(streamer, errors, generate_kwargs):
    """Run ``model.generate`` on the worker thread, forwarding any failure.

    On error the exception is appended to ``errors`` for the consuming thread to
    re-raise, and the streamer is ended so the consumer stops waiting.
    """
    try:
        model.generate(**generate_kwargs)
    except Exception as exc:  # thread boundary: re-raised by chat_inference
        errors.append(exc)
        # generate() only ends the streamer when it finishes normally; end it
        # here so the consumer does not block until the 60 s queue timeout.
        streamer.end()


def chat_inference(message, history):
    """Stream the model's reply to ``message`` given the prior chat ``history``.

    Yields the accumulated reply text after each new chunk of decoded output.
    Any exception raised by ``model.generate`` is re-raised here.
    """
    global stop_event
    stop_event.set()  # halt any previous generation still in flight
    stop_event = Event()
    current_stop = stop_event

    input_text = tokenizer.apply_chat_template(
        _build_messages(message, history),
        tokenize=False,
        add_generation_prompt=True,
        enable_thinking=False,
    )
    inputs = tokenizer([input_text], return_tensors="pt").to("cpu")
    streamer = TextIteratorStreamer(
        tokenizer,
        timeout=60.0,
        skip_prompt=True,
        skip_special_tokens=True,
    )
    generate_kwargs = dict(
        **inputs,
        streamer=streamer,
        max_new_tokens=512,
        do_sample=True,
        temperature=0.7,
        top_p=0.8,
        # Lets current_stop interrupt generate() itself; otherwise the worker
        # keeps decoding up to max_new_tokens after the consumer stops reading.
        stopping_criteria=StoppingCriteriaList([_stop_when_set(current_stop)]),
    )
    errors = []
    worker = Thread(
        target=_generate_into,
        args=(streamer, errors, generate_kwargs),
        daemon=True,
    )
    worker.start()

    partial_message = ""
    try:
        for new_text in streamer:
            if current_stop.is_set():
                # The stopping criterion ends generation promptly; no need to drain.
                break
            partial_message += new_text
            yield partial_message
        if errors:
            raise errors[0]
    finally:
        # Also reached when this generator is closed early (GeneratorExit at
        # yield): signal the worker to stop. Harmless after a normal finish.
        current_stop.set()
        # Best-effort release; the worker thread may still hold its kwargs
        # until it returns.
        del inputs, generate_kwargs
        gc.collect()
# Streaming chat UI backed by chat_inference; at most one request is processed
# at a time, and a "Stop" button is shown while a reply is streaming.
demo = gr.ChatInterface(
    chat_inference,
    title="Indian History SLM",
    description="Ask me anything about Indian History!",
    concurrency_limit=1,
    stop_btn="Stop",
)

if __name__ == "__main__":
    demo.launch()