File size: 1,599 Bytes
2d02001
 
1003d43
 
2d02001
 
 
 
 
 
fb77952
2d02001
 
b7b6a81
fb77952
 
 
 
2d02001
 
fb77952
 
 
 
2d02001
 
 
fb77952
2d02001
 
 
 
 
 
b7b6a81
fb77952
 
 
 
 
 
b7b6a81
fb77952
 
 
2d02001
b7b6a81
fb77952
2d02001
b7b6a81
fb77952
2d02001
 
 
fb77952
2d02001
 
 
fb77952
 
 
 
b7b6a81
fb77952
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
import gradio as gr
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

# Identifier handed to `from_pretrained` for both the tokenizer and the model.
MODEL_ID = "LivSterling/rc-tutor-llama3-merged"

tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
# LLaMA-3 padding fix (tokenizer side): reuse the EOS token as the pad token.
tokenizer.pad_token = tokenizer.eos_token

# Weights are requested in float16; device_map="auto" leaves placement to the library.
model = AutoModelForCausalLM.from_pretrained(MODEL_ID, device_map="auto", dtype=torch.float16)
# LLaMA-3 padding fix (model side): keep the model config's pad id in sync.
model.config.pad_token_id = tokenizer.eos_token_id


def _content_to_text(content):
    """Return the plain-text form of one chat message's ``content`` field."""
    if content is None:
        return ""
    if isinstance(content, str):
        return content
    if isinstance(content, list):
        # NOTE(review): some Gradio versions appear to deliver content as a
        # list of parts such as {"type": "text", "text": ...}; only the text
        # parts are kept here -- confirm against the installed Gradio version.
        pieces = []
        for part in content:
            if isinstance(part, str):
                pieces.append(part)
            elif isinstance(part, dict) and isinstance(part.get("text"), str):
                pieces.append(part["text"])
        return "\n".join(pieces)
    return str(content)


def _history_to_messages(history):
    """Convert Gradio chat history into a list of ``role``/``content`` dicts.

    Accepts both history shapes: "messages" style (dicts carrying ``role``
    and ``content``) and the legacy "tuples" style (``(user, bot)`` pairs,
    where either side may be ``None``).
    """
    messages = []
    for turn in history or []:
        if isinstance(turn, dict):
            role = turn.get("role")
            content = turn.get("content")
            if role and content is not None:
                messages.append({"role": role, "content": _content_to_text(content)})
            continue

        user, bot = turn
        if user is not None:
            messages.append({"role": "user", "content": _content_to_text(user)})
        if bot is not None:
            messages.append({"role": "assistant", "content": _content_to_text(bot)})
    return messages


def chat_fn(message, history):
    """Generate the assistant's reply to ``message`` given prior ``history``.

    Args:
        message: The new user message.
        history: Previous turns, either as ``(user, bot)`` pairs or as dicts
            with ``role`` and ``content`` keys.

    Returns:
        The decoded reply with special tokens removed and surrounding
        whitespace stripped. Generation is greedy (``do_sample=False``) and
        capped at 256 new tokens.
    """
    messages = _history_to_messages(history)
    messages.append({"role": "user", "content": message})

    # return_dict=True makes the return type explicit (a mapping holding
    # input_ids and, normally, attention_mask) instead of relying on the
    # library's default return type for tokenize=True.
    encoded = tokenizer.apply_chat_template(
        messages,
        add_generation_prompt=True,
        tokenize=True,
        return_tensors="pt",
        return_dict=True,
    ).to(model.device)

    input_ids = encoded["input_ids"]
    # This file aliases the pad token to the EOS token, so building the mask
    # as `input_ids != pad_token_id` would hide every EOS token that the chat
    # template places in the prompt. The prompt is a single unpadded
    # sequence, so use the tokenizer's mask, or all ones if it supplies none.
    attention_mask = encoded.get("attention_mask")
    if attention_mask is None:
        attention_mask = torch.ones_like(input_ids)

    with torch.no_grad():
        outputs = model.generate(
            input_ids=input_ids,
            attention_mask=attention_mask,
            max_new_tokens=256,
            do_sample=False,
            pad_token_id=tokenizer.eos_token_id,
            eos_token_id=tokenizer.eos_token_id,
        )

    # generate() returns prompt + continuation; keep only the new tokens.
    prompt_length = input_ids.shape[-1]
    reply_ids = outputs[0][prompt_length:]
    return tokenizer.decode(reply_ids, skip_special_tokens=True).strip()


# Chat UI whose replies are produced by chat_fn.
demo = gr.ChatInterface(chat_fn, title="RC Tutor LLaMA-3")

if __name__ == "__main__":
    # Launch settings gathered in one place and forwarded to Gradio unchanged.
    launch_options = {
        "server_name": "0.0.0.0",
        "server_port": 7860,
        "share": False,
        "ssr_mode": False,
        "show_error": True,
    }
    demo.launch(**launch_options)