File size: 1,667 Bytes
c82b2ab
e55a249
 
c82b2ab
e55a249
 
c82b2ab
e55a249
 
 
 
c82b2ab
e55a249
 
 
 
9f63a22
 
e55a249
 
9f63a22
 
 
e55a249
9f63a22
e55a249
 
 
 
 
 
9f63a22
e55a249
 
c82b2ab
e55a249
9f63a22
 
 
e55a249
c82b2ab
e55a249
c82b2ab
 
 
 
e55a249
 
c82b2ab
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Run on the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Hugging Face Hub id of the checkpoint served by this app.
model_path = "ibm-granite/granite-4.0-h-350M"

# Model and tokenizer are loaded once at import time and shared by every request.
tokenizer = AutoTokenizer.from_pretrained(model_path)
model = AutoModelForCausalLM.from_pretrained(
    model_path,
    device_map=device,
)
# Put the model in evaluation mode for inference.
model.eval()

def _history_to_chat(message, history):
    """Build the chat-template message list for a single model call.

    Args:
        message: The new user message text.
        history: Earlier turns as Gradio "messages"-format dicts (each with
            "role" and "content" keys), or None.

    Returns:
        A new list of {"role", "content"} dicts containing every prior user
        and assistant text turn in order, followed by the new user message.
    """
    chat = []
    for turn in history or []:
        role = turn.get("role")
        content = turn.get("content")
        # Keep both sides of the conversation so the model sees its own
        # earlier answers; skip non-text entries that the chat template
        # cannot render as plain text.
        if role in ("user", "assistant") and isinstance(content, str):
            chat.append({"role": role, "content": content})
    chat.append({"role": "user", "content": message})
    return chat


def respond(message, history, max_new_tokens, temperature):
    """Generate the assistant's reply to ``message`` for ``gr.ChatInterface``.

    Args:
        message: The latest user message.
        history: Earlier turns in Gradio "messages" format (list of dicts with
            "role"/"content"); may be None or empty. It is not modified.
        max_new_tokens: Upper bound on the number of generated tokens.
        temperature: Sampling temperature applied during generation.

    Returns:
        Only the decoded reply text, with the prompt and special tokens
        stripped. ``gr.ChatInterface`` appends the user and assistant turns
        to the history itself, so this function does not.
    """
    chat = _history_to_chat(message, history)

    chat_text = tokenizer.apply_chat_template(
        chat, tokenize=False, add_generation_prompt=True
    )
    input_tokens = tokenizer(chat_text, return_tensors="pt").to(device)
    prompt_len = input_tokens["input_ids"].shape[1]

    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        output_tokens = model.generate(
            **input_tokens,
            # Slider values may arrive as floats; generate() needs an int here.
            max_new_tokens=int(max_new_tokens),
            # Sampling must be on for the temperature setting to take effect.
            do_sample=True,
            temperature=float(temperature),
        )

    # generate() returns prompt + continuation; decode only the new tokens.
    reply_ids = output_tokens[0, prompt_len:]
    return tokenizer.decode(reply_ids, skip_special_tokens=True).strip()

# Chat UI. With type="messages", history reaches `respond` as a list of
# role/content dicts; the slider values follow (message, history) as extra
# positional arguments, in the order listed here.
chatbot = gr.ChatInterface(
    fn=respond,
    type="messages",
    additional_inputs=[
        gr.Slider(
            label="Max new tokens",
            minimum=1,
            maximum=1024,
            step=1,
            value=200,
        ),
        gr.Slider(
            label="Temperature",
            minimum=0.1,
            maximum=2.0,
            step=0.05,
            value=0.7,
        ),
    ],
)

# Render the chat interface inside a Blocks container.
demo = gr.Blocks()
with demo:
    chatbot.render()

if __name__ == "__main__":
    # Start the Gradio web server when this file is run as a script.
    demo.launch()