from time import sleep
import streamlit as st
# for GPU inference, uncomment the following line
# from unsloth import FastLanguageModel, is_bfloat16_supported
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
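# Assumed minimal dependencies for CPU inference (not pinned):
#   pip install streamlit transformers torch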

AI_MODE = "ON"

if AI_MODE == "ON":
    model_id = "choco-conoz/TwinLlama-3.2-1B-DPO"
    # model_id = "choco-conoz/TwinLlama-3.2-1B"
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    model = AutoModelForCausalLM.from_pretrained(model_id)
    # for GPU inference, uncomment the following line
    # model = FastLanguageModel.for_inference(model)
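    # Alternative GPU path with plain transformers (a sketch; assumes torch
    # with CUDA and the accelerate package are installed):
    # import torch
    # model = AutoModelForCausalLM.from_pretrained(
    #     model_id, torch_dtype=torch.bfloat16, device_map="auto")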

    processor = pipeline(
        "text-generation",
        model=model,
        tokenizer=tokenizer,
        max_new_tokens=256
    )

    # Stop generation at either the generic EOS token or Llama-3's
    # end-of-turn token <|eot_id|>.
    terminators = [
        tokenizer.eos_token_id,
        tokenizer.convert_tokens_to_ids("<|eot_id|>"),
    ]


def main():
    st.title('DEMO - DPO')
    st.subheader('Instruction/Response')
    st.markdown('<div style="text-align: right;">produced by Conoz (https://www.conoz.com)</div>',
                unsafe_allow_html=True)
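    # Korean-language description (the English equivalent is rendered below)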
    st.markdown(
        '<div><br />basic space hardware์—์„œ ์‘๋‹ต์‹œ๊ฐ„์€ 3๋ถ„ ์ •๋„ ์†Œ์š”๋ฉ๋‹ˆ๋‹ค. '
        '์˜์–ด, ํ•œ๊ตญ์–ด ๋“ฑ์œผ๋กœ ์งˆ๋ฌธํ•  ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค.<br />'
        '์ฝ”๋…ธ์ฆˆ์—์„œ Llama-3.2-1B model์„ DPO๋กœ ํ•™์Šตํ•œ ๋ชจ๋ธ์„ ์‚ฌ์šฉํ•ฉ๋‹ˆ๋‹ค. '
        '์•ŒํŒŒ์นด chat template์„ ์‚ฌ์šฉํ•ฉ๋‹ˆ๋‹ค.<br />'
        '์ฝ”๋…ธ์ฆˆ์—์„œ Llama-3.1-8B ๋ชจ๋ธ์„ DPO๋กœ ํ•™์Šตํ•œ ๋ชจ๋ธ๋กœ ์‚ฌ์šฉํ•  ์ˆ˜๋„ ์žˆ์ง€๋งŒ basic space hardware ์—์„  ๋™์ž‘ํ•˜์ง€ ์•Š์Šต๋‹ˆ๋‹ค.</div>',
        unsafe_allow_html=True
    )
    st.markdown(
        '<div>Response time on the basic space hardware is about 3 minutes. '
        'You can ask questions in English, Korean, etc. '
        'This demo uses a Llama-3.2-1B model fine-tuned with DPO by Conoz. '
        'It uses the Alpaca chat template.<br />'
        'A Llama-3.1-8B model fine-tuned with DPO by Conoz is also available, but it does not run on the basic space hardware.<br /></div>',
        unsafe_allow_html=True
    )
    st.markdown('<hr />', unsafe_allow_html=True)
    query = st.text_input('Enter your topic of interest (10 to 1000 characters).',
                          placeholder='e.g. What is the capital of South Korea?')

    alpaca_template = """Below is an instruction that describes a task. Write a response that appropriately completes the request.\
    ## Instruction:
    {}
    ## Response:
    """

    if st.button("Send"):
        if not query:
            st.error("Please enter a query.")
            return
        if len(query) < 10:
            st.error("Please enter a query with at least 10 characters.")
            return
        if len(query) > 1000:
            st.error("Please enter a query with less than 1000 characters.")
            return
        with st.spinner("Generating response..."):
            user_prompt = alpaca_template.format(query, "")
            if AI_MODE == "ON":
                # For chat models, the prompt could instead be built with the
                # tokenizer's chat template:
                # messages = [{"role": "user", "content": user_prompt}]
                # user_prompt = tokenizer.apply_chat_template(
                #     messages, tokenize=False, add_generation_prompt=True)
                outputs = processor(user_prompt,
                                    max_new_tokens=256,
                                    # num_return_sequences=1,
                                    do_sample=True,  # sampling must be enabled for temperature/top_p to apply
                                    temperature=0.6,
                                    top_p=0.9,
                                    eos_token_id=terminators,
                                    )
                # The pipeline output echoes the prompt, so strip it off.
                response = outputs[0]["generated_text"][len(user_prompt):]
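                # Alternatively, the text-generation pipeline accepts
                # return_full_text=False, which would skip the manual
                # prompt-stripping above.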
            else:
                sleep(3)
                response = "AI_MODE is OFF. Please turn it ON to get a response."
            st.subheader('Response:')
        st.write(response)


if __name__ == "__main__":
    main()
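
# To run the demo locally (assuming this file is saved as app.py):
#   streamlit run app.py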