File size: 4,556 Bytes
5f6ca02
beb3fde
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5f6ca02
 
 
 
beb3fde
5f6ca02
 
 
 
 
 
beb3fde
5f6ca02
beb3fde
5f6ca02
beb3fde
 
 
 
 
 
5f6ca02
 
beb3fde
 
 
 
 
 
5f6ca02
beb3fde
 
 
 
 
 
 
 
 
5f6ca02
beb3fde
 
 
 
5f6ca02
 
 
 
beb3fde
5f6ca02
 
 
beb3fde
 
5f6ca02
beb3fde
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5f6ca02
 
 
 
 
 
 
 
beb3fde
 
 
 
 
 
5f6ca02
 
beb3fde
 
 
 
5f6ca02
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
import gradio as gr
import torch
import transformers
import os

# --- ๋ชจ๋ธ ์„ค์ • ---
# ์‚ฌ์šฉํ•  ๋ชจ๋ธ ID๋ฅผ ์ง€์ •ํ•ฉ๋‹ˆ๋‹ค.
MODEL_ID = "Qwen/Qwen2.5-Coder-7B-Instruct"

# --- ๋ชจ๋ธ ๋กœ๋”ฉ (Space๊ฐ€ ์‹œ์ž‘๋  ๋•Œ ํ•œ ๋ฒˆ๋งŒ ์‹คํ–‰๋ฉ๋‹ˆ๋‹ค) ---
print("๋ชจ๋ธ์„ ๋กœ๋“œํ•˜๋Š” ์ค‘์ž…๋‹ˆ๋‹ค... ์ดˆ๊ธฐ ์‹คํ–‰ ์‹œ ์‹œ๊ฐ„์ด ๋‹ค์†Œ ๊ฑธ๋ฆด ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค.")
try:
    # 4๋น„ํŠธ ์–‘์žํ™”๋กœ VRAM ์‚ฌ์šฉ๋Ÿ‰์„ ์ค„์ž…๋‹ˆ๋‹ค. (T4 GPU์—์„œ ์‹คํ–‰ ๊ฐ€๋Šฅ)
    model = transformers.AutoModelForCausalLM.from_pretrained(
        MODEL_ID,
        torch_dtype=torch.bfloat16, # T4 GPU์™€ ํ˜ธํ™˜๋˜๋Š” ๋ฐ์ดํ„ฐ ํƒ€์ž…
        device_map="auto",          # ์ž๋™์œผ๋กœ GPU์— ํ• ๋‹น
        load_in_4bit=True,          # 4๋น„ํŠธ ์–‘์žํ™” ํ™œ์„ฑํ™”
    )
    tokenizer = transformers.AutoTokenizer.from_pretrained(MODEL_ID)
    
    # ํ…์ŠคํŠธ ์ƒ์„ฑ ํŒŒ์ดํ”„๋ผ์ธ์„ ๋ฏธ๋ฆฌ ๋งŒ๋“ค์–ด ๋‘ก๋‹ˆ๋‹ค.
    text_generator = transformers.pipeline(
        "text-generation",
        model=model,
        tokenizer=tokenizer,
    )
    print("โœ… ๋ชจ๋ธ ๋กœ๋”ฉ ์™„๋ฃŒ!")
except Exception as e:
    print(f"โŒ ๋ชจ๋ธ ๋กœ๋”ฉ ์‹คํŒจ: {e}")
    # ๋ชจ๋ธ ๋กœ๋”ฉ์— ์‹คํŒจํ•˜๋ฉด ์˜ค๋ฅ˜ ๋ฉ”์‹œ์ง€๋ฅผ ๋ฐ˜ํ™˜ํ•˜๋Š” ๋”๋ฏธ ํ•จ์ˆ˜๋กœ ๋Œ€์ฒด
    def text_generator(*args, **kwargs):
        yield "๋ชจ๋ธ์„ ๋กœ๋“œํ•˜๋Š” ๋ฐ ์‹คํŒจํ–ˆ์Šต๋‹ˆ๋‹ค. Space์˜ ํ•˜๋“œ์›จ์–ด ์„ค์ •์„ ํ™•์ธํ•˜๊ฑฐ๋‚˜ ๋ชจ๋ธ ์ด๋ฆ„์ด ์˜ฌ๋ฐ”๋ฅธ์ง€ ํ™•์ธํ•ด์ฃผ์„ธ์š”."


def respond(
    message,
    history,
    system_message,
    max_tokens,
    temperature,
    top_p,
):
    """Generate a reply to the user's message.

    Args:
        message: Latest user message (str).
        history: Prior turns. With ``gr.ChatInterface(type="messages")``
            each entry is a ``{"role": ..., "content": ...}`` dict; legacy
            Gradio passed ``(user, assistant)`` tuples — both are accepted.
        system_message: System prompt prepended to the conversation.
        max_tokens: Maximum number of new tokens to generate.
        temperature: Sampling temperature.
        top_p: Nucleus-sampling probability mass.

    Yields:
        str: The assistant's reply. Yielded (rather than returned) so
        Gradio treats this handler as streaming-capable.
    """
    # Format the conversation the way Qwen's chat template expects.
    messages = [{"role": "system", "content": system_message}]

    # BUG FIX: with type="messages", history entries are role/content dicts,
    # not (user, bot) tuples — unpacking them as tuples raised a ValueError.
    # Accept both shapes for robustness.
    for entry in history:
        if isinstance(entry, dict):
            messages.append({"role": entry["role"], "content": entry["content"]})
        else:
            user_msg, bot_msg = entry
            messages.append({"role": "user", "content": user_msg})
            messages.append({"role": "assistant", "content": bot_msg})

    messages.append({"role": "user", "content": message})

    # Render the conversation with the tokenizer's chat template. If model
    # loading failed at startup, `tokenizer` was never defined; fall back to
    # the raw message so the fallback generator can still report its error.
    try:
        prompt = tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True,
        )
    except NameError:
        prompt = message

    generation_args = {
        "max_new_tokens": max_tokens,
        "temperature": temperature,
        "top_p": top_p,
        "do_sample": True,
        # BUG FIX: the text-generation pipeline accepts no `stream=True`
        # kwarg (it was forwarded to generate() and rejected), so the reply
        # is produced in a single call below.
    }

    # The pipeline returns [{"generated_text": prompt + completion}];
    # slice off the prompt to keep only the newly generated completion.
    outputs = text_generator(prompt, **generation_args)
    yield outputs[0]["generated_text"][len(prompt):]


"""
Gradio ChatInterface๋ฅผ ์‚ฌ์šฉํ•˜์—ฌ ์ฑ—๋ด‡ UI๋ฅผ ๋งŒ๋“ญ๋‹ˆ๋‹ค.
"""
chatbot = gr.ChatInterface(
    respond,
    type="messages", # Gradio 4.x ์ด์ƒ์˜ ์ตœ์‹  ๋ฉ”์‹œ์ง€ ํ˜•์‹ ์‚ฌ์šฉ
    additional_inputs_accordion="โš™๏ธ ๋งค๊ฐœ๋ณ€์ˆ˜ ์„ค์ •",
    additional_inputs=[
        gr.Textbox(
            value="You are Qwen2.5-Coder, created by Alibaba Cloud. You are a helpful assistant specialized in coding and programming.", 
            label="System message"
        ),
        gr.Slider(
            minimum=1, 
            maximum=4096, 
            value=1024, 
            step=1, 
            label="Max new tokens"
        ),
        gr.Slider(
            minimum=0.1, 
            maximum=4.0, 
            value=0.7, 
            step=0.1, 
            label="Temperature"
        ),
        gr.Slider(
            minimum=0.1,
            maximum=1.0,
            value=0.95,
            step=0.05,
            label="Top-p (nucleus sampling)",
        ),
    ],
    examples=[
        ["PyTorch๋กœ ๊ฐ„๋‹จํ•œ CNN ๋ชจ๋ธ์„ ๋งŒ๋“ค์–ด์ค˜."],
        ["์ด ํŒŒ์ด์ฌ ์ฝ”๋“œ๋ฅผ ์ตœ์ ํ™”ํ•ด์ค˜:\n\n```python\nfor i in range(len(my_list)):\n    print(my_list[i])\n```"],
        ["FastAPI๋กœ 'hello world'๋ฅผ ์ถœ๋ ฅํ•˜๋Š” API ์—”๋“œํฌ์ธํŠธ๋ฅผ ๋งŒ๋“ค์–ด์ค˜."],
    ],
    cache_examples=False, # ์˜ˆ์ œ ์บ์‹ฑ ๋น„ํ™œ์„ฑํ™” (๋ฉ”๋ชจ๋ฆฌ ์ ˆ์•ฝ)
)

# Gradio Blocks๋ฅผ ์‚ฌ์šฉํ•˜์—ฌ ๋ ˆ์ด์•„์›ƒ ๊ตฌ์„ฑ
with gr.Blocks(theme=gr.themes.Soft(), title="๋‚˜๋งŒ์˜ AI ์ฝ”๋“œ ๋ฆฌ๋”") as demo:
    gr.Markdown("# ๐Ÿค– ๋‚˜๋งŒ์˜ AI ์ฝ”๋“œ ๋ฆฌ๋” (Qwen2.5-Coder)")
    gr.Markdown("์ด ์ฑ—๋ด‡์€ **Qwen2.5-Coder-7B-Instruct** ๋ชจ๋ธ์„ ๊ธฐ๋ฐ˜์œผ๋กœ ์ฝ”๋“œ๋ฅผ ์ƒ์„ฑํ•˜๊ณ  ๋ถ„์„ํ•ฉ๋‹ˆ๋‹ค.")
    chatbot.render()

if __name__ == "__main__":
    demo.launch()