import gc
import os
from threading import Event, Thread

# Cap OpenMP worker threads. This is assigned before torch is imported so the
# value is already in the environment when torch's native libraries load.
os.environ["OMP_NUM_THREADS"] = "2"

import gradio as gr
import torch
from huggingface_hub import hf_hub_download
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    StoppingCriteria,
    StoppingCriteriaList,
    TextIteratorStreamer,
)

# Limit the number of CPU threads PyTorch uses to two.
torch.set_num_threads(2)

# Hugging Face Hub repository that provides the tokenizer, the chat template
# file and the model weights.
model_path = "ruhzi/Indian_History_SLM"
tokenizer = AutoTokenizer.from_pretrained(model_path)

# Download the repository's Jinja chat template and install it on the
# tokenizer so apply_chat_template() renders prompts with it.
template_file = hf_hub_download(repo_id=model_path, filename="chat_template.jinja")
with open(template_file, mode="r", encoding="utf-8") as template_fh:
    tokenizer.chat_template = template_fh.read()

# Load the weights in float32.
model = AutoModelForCausalLM.from_pretrained(
    model_path, dtype=torch.float32, low_cpu_mem_usage=True
)

# Event belonging to the most recent chat request. chat_inference() sets the
# current event and then replaces it with a fresh one at the start of each call.
stop_event = Event()


def _content_to_text(content):
    """Return a message's content as plain text.

    Strings pass through unchanged. A list/tuple of parts is flattened by
    concatenating string parts and the ``"text"`` field of dict parts; other
    parts are skipped. ``None`` becomes an empty string.
    """
    if isinstance(content, str):
        return content
    if content is None:
        return ""
    if isinstance(content, (list, tuple)):
        # NOTE(review): some Gradio releases appear to deliver content as a
        # list of blocks such as {"type": "text", "text": ...} -- confirm
        # against the installed version.
        pieces = []
        for part in content:
            if isinstance(part, str):
                pieces.append(part)
            elif isinstance(part, dict) and isinstance(part.get("text"), str):
                pieces.append(part["text"])
        return "".join(pieces)
    return str(content)


def _history_to_messages(history, max_entries=6):
    """Convert the most recent chat history entries into role/content dicts.

    Accepts both history layouts: a list of ``{"role", "content"}`` dicts, or
    a list of ``(user_text, assistant_text)`` pairs. Only the last
    *max_entries* entries are kept so the prompt stays short. ``None`` or an
    empty history yields an empty list.
    """
    messages = []
    for entry in list(history or [])[-max_entries:]:
        if isinstance(entry, dict):
            messages.append(
                {"role": entry["role"], "content": _content_to_text(entry["content"])}
            )
            continue
        # Pair-style entry: one user turn followed by the assistant's reply
        # (either side may be missing/empty).
        user_text, assistant_text = entry
        if user_text:
            messages.append({"role": "user", "content": _content_to_text(user_text)})
        if assistant_text:
            messages.append(
                {"role": "assistant", "content": _content_to_text(assistant_text)}
            )
    return messages


def _stop_criteria_for(stop_flag):
    """Build stopping criteria that end ``model.generate`` once *stop_flag* is set."""

    # Defined lazily so the StoppingCriteria base class is only needed when a
    # request is actually served.
    class _StopOnEvent(StoppingCriteria):
        def __call__(self, input_ids, scores, **kwargs):
            # One boolean per sequence in the batch, all mirroring the flag.
            return torch.full(
                (input_ids.shape[0],),
                stop_flag.is_set(),
                dtype=torch.bool,
                device=input_ids.device,
            )

    return StoppingCriteriaList([_StopOnEvent()])


def chat_inference(message, history):
    """Stream the model's reply to *message* as a growing string.

    Args:
        message: The user's new message.
        history: Earlier turns of the conversation, either as role/content
            dicts or as (user, assistant) pairs. Only the last six entries
            are included in the prompt.

    Yields:
        The reply accumulated so far, once per chunk emitted by the streamer.

    Generation runs on a background thread. It is told to stop when a newer
    call supersedes this one or when this generator finishes or is closed
    (for example after the UI's stop button cancels the request), so an
    abandoned request does not keep generating in the background.
    """
    global stop_event

    # Signal any generation left over from the previous call, then install a
    # fresh event that belongs to this request only.
    stop_event.set()
    stop_event = Event()
    current_stop = stop_event

    messages = _history_to_messages(history)
    messages.append({"role": "user", "content": message})

    input_text = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True,
        enable_thinking=False,
    )

    inputs = tokenizer([input_text], return_tensors="pt").to("cpu")

    # Iterating the streamer raises queue.Empty if no text arrives within the
    # timeout; that error is deliberately left to propagate to the caller.
    streamer = TextIteratorStreamer(
        tokenizer,
        timeout=60.0,
        skip_prompt=True,
        skip_special_tokens=True,
    )

    generate_kwargs = dict(
        **inputs,
        streamer=streamer,
        # Lets the event interrupt generation between tokens instead of
        # always running to max_new_tokens.
        stopping_criteria=_stop_criteria_for(current_stop),
        max_new_tokens=512,
        do_sample=True,
        temperature=0.7,
        top_p=0.8,
    )

    worker = Thread(target=model.generate, kwargs=generate_kwargs, daemon=True)
    worker.start()

    partial_message = ""
    try:
        for new_text in streamer:
            if current_stop.is_set():
                # Superseded by a newer request: stop emitting right away. The
                # worker ends on its own through the stopping criteria, so
                # there is no need to drain the remaining output.
                break
            partial_message += new_text
            yield partial_message
    finally:
        # Reached on normal completion, on error, and when the generator is
        # closed early; in every case make sure the worker stops generating.
        current_stop.set()
        del inputs, generate_kwargs
        gc.collect()


# Chat UI backed by chat_inference. The stop button is labelled "Stop" and the
# concurrency limit for the chat function is set to 1.
demo = gr.ChatInterface(
    chat_inference,
    title="Indian History SLM",
    description="Ask me anything about Indian History!",
    stop_btn="Stop",
    concurrency_limit=1,
)


# Start the web app only when this file is executed as a script.
if __name__ == "__main__":
    demo.launch()