File size: 3,754 Bytes
8f136d4
d5f4c16
 
9c878f1
 
8f136d4
d5f4c16
e89521f
d5f4c16
 
9c878f1
d5f4c16
 
9c878f1
d5f4c16
 
 
 
8f136d4
 
 
 
 
 
 
 
 
 
d5f4c16
8f136d4
d5f4c16
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8f136d4
9c878f1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8f136d4
 
 
 
 
 
 
d5f4c16
 
 
8f136d4
 
 
 
 
 
 
 
 
 
 
d5f4c16
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
import gradio as gr
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from transformers import TextIteratorStreamer
from threading import Thread

# Load model and tokenizer once at import time so every request reuses them.
model_name = "Qwen/Qwen2.5-0.5B-Instruct"
print(f"Loading model {model_name}...")
tokenizer = AutoTokenizer.from_pretrained(model_name)
# Some tokenizers ship without a pad token; reuse EOS so generation can pad.
tokenizer.pad_token = tokenizer.eos_token
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.float16,  # half precision to halve weight memory
    device_map="auto",          # place weights on GPU if available, else CPU
    low_cpu_mem_usage=True
)
print("Model loaded successfully!")

def respond(
    message,
    history: list[dict[str, str]],
    system_message,
    max_tokens,
    temperature,
    top_p,
):
    """
    Stream a chat completion from the locally loaded Qwen2.5 model.

    Args:
        message: The user's latest message.
        history: Prior turns as ``{"role": ..., "content": ...}`` dicts
            (Gradio ``type="messages"`` format).
        system_message: Optional system prompt prepended to the conversation.
        max_tokens: Maximum number of new tokens to generate.
        temperature: Sampling temperature passed to ``generate()``.
        top_p: Nucleus-sampling probability mass.

    Yields:
        The accumulated response text after each streamed chunk, so the
        Gradio UI re-renders the growing reply.
    """
    # Assemble the full conversation in chat-message format.
    conversation = []
    if system_message:
        conversation.append({"role": "system", "content": system_message})
    conversation.extend(history)
    conversation.append({"role": "user", "content": message})

    # Prefer the tokenizer's chat template; fall back to a plain
    # "Role: content" transcript if the tokenizer doesn't define one.
    # NOTE: `except Exception` (not a bare `except:`) so Ctrl-C and
    # interpreter shutdown are not swallowed here.
    try:
        formatted_prompt = tokenizer.apply_chat_template(
            conversation,
            tokenize=False,
            add_generation_prompt=True
        )
    except Exception:
        formatted_prompt = ""
        for msg in conversation:
            role = msg.get("role", "user")
            content = msg.get("content", "")
            if role == "system":
                formatted_prompt += f"System: {content}\n"
            elif role == "user":
                formatted_prompt += f"User: {content}\n"
            elif role == "assistant":
                formatted_prompt += f"Assistant: {content}\n"
        formatted_prompt += "Assistant: "

    # Tokenize (truncating long prompts) and move tensors to the model device.
    inputs = tokenizer(formatted_prompt, return_tensors="pt", truncation=True, max_length=2048)
    inputs = {k: v.to(model.device) for k, v in inputs.items()}

    # skip_prompt=True so the echoed input prompt is not streamed back.
    streamer = TextIteratorStreamer(tokenizer, skip_special_tokens=True, skip_prompt=True)

    generation_kwargs = {
        **inputs,
        "max_new_tokens": max_tokens,
        "temperature": temperature,
        "top_p": top_p,
        "do_sample": True,
        "no_repeat_ngram_size": 3,
        "repetition_penalty": 1.1,
        "pad_token_id": tokenizer.eos_token_id,
        "streamer": streamer,
    }

    # generate() blocks until completion, so run it on a worker thread and
    # consume the streamer on this one, yielding partial text as it arrives.
    thread = Thread(target=model.generate, kwargs=generation_kwargs)
    thread.start()

    response = ""
    for new_text in streamer:
        response += new_text
        yield response

    thread.join()
"""
For information on how to customize the ChatInterface, peruse the gradio docs: https://www.gradio.app/docs/chatinterface
"""
chatbot = gr.ChatInterface(
    respond,
    type="messages",
    additional_inputs=[
        gr.Textbox(value="You are a friendly Chatbot proficient in Indian languages.", label="System message"),
        gr.Slider(minimum=1, maximum=512, value=256, step=1, label="Max new tokens"),
        gr.Slider(minimum=0.1, maximum=2.0, value=0.7, step=0.1, label="Temperature"),
        gr.Slider(
            minimum=0.1,
            maximum=1.0,
            value=0.95,
            step=0.05,
            label="Top-p (nucleus sampling)",
        ),
    ],
)

if __name__ == "__main__":
    chatbot.launch()