File size: 1,136 Bytes
d5b5f52
 
 
 
 
e94343a
 
84bfe0d
d5b5f52
 
 
 
 
 
 
 
 
 
 
 
 
 
efc1f81
d5b5f52
 
 
 
 
 
 
 
 
 
 
 
 
 
e94343a
d5b5f52
e94343a
d5b5f52
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
# standard library
import functools
import os
import re

# import python-dotenv
from dotenv import load_dotenv

# import from huggingface
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch

# import gradio for gui
import gradio as gr

# TinyLlama-Chat is an ungated Hugging Face model, so no access token is needed.
model_id = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
dtype = torch.bfloat16


@functools.lru_cache(maxsize=1)
def _load_model():
    """Load the tokenizer and model for ``model_id`` once and cache them.

    Loading a 1.1B-parameter checkpoint is expensive, so it happens on the
    first chat message only instead of on every message.

    Returns:
        A ``(tokenizer, model)`` tuple.
    """
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        device_map=None,
        torch_dtype=dtype,
    )
    return tokenizer, model


def _history_to_messages(history):
    """Convert Gradio chat history into chat-template message dicts.

    Accepts either history shape Gradio's ChatInterface may pass: a list of
    ``[user, assistant]`` pairs, or a list of ``{"role", "content"}`` dicts.
    Entries that are not plain text (``None`` placeholders, uploaded files)
    are skipped.
    """
    messages = []
    for turn in history or []:
        if isinstance(turn, dict):
            pairs = [(turn.get("role"), turn.get("content"))]
        else:
            pairs = [("user", turn[0]), ("assistant", turn[1])]
        for role, content in pairs:
            if role in ("system", "user", "assistant") and isinstance(content, str):
                messages.append({"role": role, "content": content})
    return messages


def tinyllama_chat(message, history):
    """Generate TinyLlama's reply to ``message``, given the earlier ``history``.

    Used as the ``fn`` callback of ``gr.ChatInterface``.

    Args:
        message: The user's newest message.
        history: Earlier turns as supplied by Gradio (pairs or message dicts).

    Returns:
        The assistant's reply text, without the prompt or special tokens.
    """
    tokenizer, model = _load_model()

    # NOTE(review): long conversations can exceed the model's context window
    # (believed to be 2048 tokens for TinyLlama, TODO confirm); consider
    # trimming the oldest turns if that becomes a problem.
    chat = _history_to_messages(history)
    chat.append({"role": "user", "content": message})

    prompt = tokenizer.apply_chat_template(chat, tokenize=False, add_generation_prompt=True)
    inputs = tokenizer.encode(prompt, add_special_tokens=False, return_tensors="pt")
    inputs = inputs.to(model.device)

    outputs = model.generate(
        input_ids=inputs,
        # A single unpadded sequence: every position is a real token.
        attention_mask=torch.ones_like(inputs),
        max_new_tokens=2048,
    )

    # Decode only the newly generated tokens (everything after the prompt).
    # Splitting the full decoded text on the word "assistant" instead would
    # leave "|>" from the "<|assistant|>" tag at the start, keep the "</s>"
    # token, and cut off any reply that itself mentions "assistant".
    new_tokens = outputs[0][inputs.shape[-1]:]
    return tokenizer.decode(new_tokens, skip_special_tokens=True).strip()

# Build the chat UI as `demo` (the variable name Gradio's auto-reload mode
# looks for by default), but only start the web server when this file is run
# as a script, so importing the module has no server side effect.
demo = gr.ChatInterface(tinyllama_chat)

if __name__ == "__main__":
    demo.launch()