File size: 2,369 Bytes
9aed480
 
 
fd449f8
c0a8087
9aed480
 
4185a6c
c0741fa
c0a8087
9aed480
 
c0741fa
c0a8087
4185a6c
 
 
 
c0a8087
 
6a1eaed
9aed480
c0a8087
fd449f8
c0741fa
 
82cd61c
c0741fa
 
 
 
5acdaf9
c0741fa
5acdaf9
254de85
5acdaf9
 
 
254de85
c0741fa
254de85
fd449f8
c0a8087
4185a6c
c0741fa
c0a8087
c0741fa
9aed480
c0a8087
c0741fa
 
 
 
 
 
254de85
 
 
 
c0741fa
c0a8087
254de85
4185a6c
c0a8087
 
c0741fa
254de85
 
 
c0741fa
 
 
 
 
 
 
 
 
 
 
 
c0a8087
82cd61c
c0a8087
82cd61c
a94700b
5acdaf9
c0741fa
fd449f8
 
 
c0a8087
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
import os
# Set before the torch import below so the OpenMP thread cap is already in the
# environment when the numeric libraries load.
# NOTE(review): presumably sized for a small (2 vCPU) host -- confirm.
os.environ["OMP_NUM_THREADS"] = "2"

import gradio as gr
import torch
import gc
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
from huggingface_hub import hf_hub_download
from threading import Thread, Event

# Cap torch's own thread pool to match the OMP_NUM_THREADS value set above.
torch.set_num_threads(2)

# Hub repository id used for the tokenizer, the chat template and the weights.
model_path = "ruhzi/Indian_History_SLM"
tokenizer = AutoTokenizer.from_pretrained(model_path)
# Fetch chat_template.jinja from the same repo and install its contents as the
# tokenizer's chat template, overriding whatever template the tokenizer had.
template_file = hf_hub_download(repo_id=model_path, filename="chat_template.jinja")
with open(template_file, "r", encoding="utf-8") as f:
    tokenizer.chat_template = f.read()

# Load the model weights in float32; inputs are later moved to "cpu" by
# chat_inference, so this runs as a CPU-only deployment.
model = AutoModelForCausalLM.from_pretrained(
    model_path,
    dtype=torch.float32,
    low_cpu_mem_usage=True
)

# Module-level stop flag: chat_inference sets the current event and replaces it
# with a fresh one at the start of every request.
stop_event = Event()

def _content_to_text(content):
    """Flatten a chat message ``content`` field into a plain string.

    Strings are returned unchanged. A list/tuple of content blocks is reduced
    to the newline-joined text of its string items and of dict items carrying
    a string ``"text"`` value; other blocks (e.g. files) are skipped.
    ``None`` becomes ``""`` and anything else is passed through ``str``.
    """
    if isinstance(content, str):
        return content
    if isinstance(content, (list, tuple)):
        parts = []
        for block in content:
            if isinstance(block, str):
                parts.append(block)
            elif isinstance(block, dict) and isinstance(block.get("text"), str):
                parts.append(block["text"])
        return "\n".join(parts)
    return "" if content is None else str(content)


def _event_stopping_criteria(event):
    """Return a ``StoppingCriteriaList`` that ends generation once *event* is set."""
    # Imported here so this helper does not depend on the file's import block.
    from transformers import StoppingCriteria, StoppingCriteriaList

    class _StopOnEvent(StoppingCriteria):
        def __call__(self, input_ids, scores, **kwargs):
            # One boolean per sequence in the batch, all mirroring the flag.
            return torch.full(
                (input_ids.shape[0],),
                event.is_set(),
                dtype=torch.bool,
                device=input_ids.device,
            )

    return StoppingCriteriaList([_StopOnEvent()])


def chat_inference(message, history):
    """Stream the model's reply to *message* as a growing string.

    Args:
        message: The new user message text.
        history: Previous chat messages as a list of
            ``{"role": ..., "content": ...}`` dicts; only the last six entries
            are sent to the model. ``None`` is treated as an empty history.

    Yields:
        The reply accumulated so far, once per chunk produced by the streamer.

    Generation runs in a background thread. When this generator finishes, is
    closed, or fails, the stop flag is set so ``model.generate`` exits at its
    next stopping-criteria check instead of continuing to ``max_new_tokens``.
    """
    global stop_event

    # Signal whichever request currently owns the shared flag, then install a
    # fresh flag that belongs to this request only.
    stop_event.set()
    stop_event = Event()
    current_stop = stop_event

    # Gradio 6.x: history is a list of {"role": "...", "content": "..."} dicts.
    # NOTE(review): content may arrive as a list of content blocks rather than
    # a plain string depending on the Gradio version -- normalized below.
    messages = [
        {"role": entry["role"], "content": _content_to_text(entry["content"])}
        for entry in (history or [])[-6:]  # 6 = 3 turns
    ]
    messages.append({"role": "user", "content": message})

    input_text = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True,
        enable_thinking=False
    )

    inputs = tokenizer([input_text], return_tensors="pt").to("cpu")

    streamer = TextIteratorStreamer(
        tokenizer,
        timeout=60.0,
        skip_prompt=True,
        skip_special_tokens=True
    )

    generate_kwargs = dict(
        **inputs,
        streamer=streamer,
        # Lets the flag actually interrupt generation, not just the UI loop.
        stopping_criteria=_event_stopping_criteria(current_stop),
        max_new_tokens=512,
        do_sample=True,
        temperature=0.7,
        top_p=0.8,
    )

    worker = Thread(target=model.generate, kwargs=generate_kwargs, daemon=True)
    worker.start()

    partial_message = ""
    try:
        for new_token in streamer:
            if current_stop.is_set():
                break
            partial_message += new_token
            yield partial_message
    finally:
        # Runs on normal completion, on generator close and on errors: stop
        # the worker and wait (bounded) so it does not keep using the CPU
        # alongside the next request.
        current_stop.set()
        worker.join(timeout=30.0)
        # generate_kwargs references the same tensors as inputs; drop both.
        del inputs, generate_kwargs
        gc.collect()


# Chat UI backed by the streaming generator chat_inference.
demo = gr.ChatInterface(
    chat_inference,
    # Behavior knobs first, then the presentation strings.
    concurrency_limit=1,
    stop_btn="Stop",
    title="Indian History SLM",
    description="Ask me anything about Indian History!",
)

# Start the web server only when this file is executed as a script.
if __name__ == "__main__":
    demo.launch()