File size: 3,724 Bytes
66cc82f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0b165e5
5721c75
 
0b165e5
5721c75
 
66cc82f
0b165e5
5721c75
 
 
 
 
 
0b165e5
5721c75
 
 
 
 
 
 
 
 
 
 
 
0b165e5
5721c75
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0b165e5
5721c75
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0b165e5
5721c75
0b165e5
5721c75
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
import os
import subprocess
import sys

# Auto-install required packages before the heavy imports below.
def install_packages():
    """Install the runtime dependencies with pip.

    Issues a single ``pip install`` for all packages so the resolver
    sees them together and only one subprocess is spawned (the
    original looped one ``pip`` call per package).

    Raises:
        subprocess.CalledProcessError: if the pip invocation fails.
    """
    packages = [
        "transformers==4.45.0",  # pinned: app was written against this API
        "torch",
        "accelerate",
        "sentencepiece",
    ]
    # sys.executable guarantees we install into the interpreter
    # actually running this script, not whatever `pip` is on PATH.
    subprocess.check_call(
        [sys.executable, "-m", "pip", "install", "-q", *packages]
    )

# Install dependencies at import time so the third-party imports
# below succeed even on a fresh environment (e.g. hosted notebook).
print("ν•„μš”ν•œ νŒ¨ν‚€μ§€ μ„€μΉ˜ 쀑...")
install_packages()
print("νŒ¨ν‚€μ§€ μ„€μΉ˜ μ™„λ£Œ!")

import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch

# λͺ¨λΈ λ‘œλ”©
print("λͺ¨λΈ λ‘œλ”© 쀑...")
model_name = "microsoft/Phi-3-mini-4k-instruct"

tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
    device_map="auto",
    trust_remote_code=True
)
print("λͺ¨λΈ λ‘œλ”© μ™„λ£Œ!")

def format_chat_prompt(message, chat_history):
    """Flatten the chat history plus the new message into a transcript.

    Each prior (user, assistant) turn becomes "User: ..." /
    "Assistant: ..." lines; empty turns are skipped. The result ends
    with "Assistant:" so the model continues in the assistant role.
    """
    pieces = []
    for user_turn, bot_turn in chat_history:
        if user_turn:
            pieces.append(f"User: {user_turn}\n")
        if bot_turn:
            pieces.append(f"Assistant: {bot_turn}\n")
    pieces.append(f"User: {message}\nAssistant:")
    return "".join(pieces)

def chat(message, history):
    """Generate one assistant reply for *message* given *history*.

    Args:
        message: The user's new message.
        history: Prior (user, assistant) turns as kept by gr.Chatbot.

    Returns:
        The decoded model reply, truncated before any hallucinated
        "User:" continuation.
    """
    # Build the plain-text transcript prompt.
    prompt = format_chat_prompt(message, history)

    # Tokenize; cap at 2048 prompt tokens to leave room for the 512
    # generated tokens inside the model's 4k context window.
    inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=2048)
    if torch.cuda.is_available():
        inputs = inputs.to("cuda")

    # FIX: pass attention_mask explicitly. Since pad_token_id is set to
    # eos_token_id here, generate() cannot reliably infer the mask from
    # the input ids and emits a warning (and may mis-attend with padding).
    with torch.no_grad():
        outputs = model.generate(
            inputs.input_ids,
            attention_mask=inputs.attention_mask,
            max_new_tokens=512,
            temperature=0.7,
            do_sample=True,
            top_p=0.9,
            pad_token_id=tokenizer.eos_token_id,
            eos_token_id=tokenizer.eos_token_id,
        )

    # Decode only the newly generated tokens (slice off the prompt).
    response = tokenizer.decode(outputs[0][inputs.input_ids.shape[1]:], skip_special_tokens=True)

    # The model sometimes continues the dialogue on its own; keep only
    # the first assistant turn.
    response = response.split("User:")[0].strip()

    return response

# Gradio μΈν„°νŽ˜μ΄μŠ€ 생성
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown(
        """
        # πŸ€– AI μ±„νŒ… μ„œλΉ„μŠ€
        ### Phi-3-mini 4B λͺ¨λΈ 기반 λŒ€ν™”ν˜• AI
        자유둭게 μ§ˆλ¬Έν•˜κ³  λŒ€ν™”ν•΄λ³΄μ„Έμš”!
        """
    )
    
    chatbot = gr.Chatbot(
        height=500,
        bubble_full_width=False,
        avatar_images=(None, "πŸ€–"),
    )
    
    with gr.Row():
        msg = gr.Textbox(
            label="λ©”μ‹œμ§€ μž…λ ₯",
            placeholder="λ©”μ‹œμ§€λ₯Ό μž…λ ₯ν•˜μ„Έμš”...",
            scale=4,
            container=False
        )
        submit = gr.Button("전솑", scale=1, variant="primary")
    
    with gr.Row():
        clear = gr.Button("λŒ€ν™” μ΄ˆκΈ°ν™”")
    
    gr.Examples(
        examples=[
            "μ•ˆλ…•ν•˜μ„Έμš”! μžκΈ°μ†Œκ°œ λΆ€νƒλ“œλ €μš”.",
            "Python으둜 κ°„λ‹¨ν•œ 계산기 λ§Œλ“œλŠ” 방법 μ•Œλ €μ€˜",
            "였늘의 λͺ…μ–Έ ν•˜λ‚˜ λ“€λ €μ€˜",
            "κΈ°λΆ„ μ’‹μ•„μ§€λŠ” 농담 ν•΄μ€˜",
        ],
        inputs=msg,
        label="μ˜ˆμ‹œ 질문"
    )
    
    # 이벀트 ν•Έλ“€λŸ¬
    def respond(message, chat_history):
        bot_message = chat(message, chat_history)
        chat_history.append((message, bot_message))
        return "", chat_history
    
    msg.submit(respond, [msg, chatbot], [msg, chatbot])
    submit.click(respond, [msg, chatbot], [msg, chatbot])
    clear.click(lambda: None, None, chatbot, queue=False)

# App entry point. queue() serializes inference requests so concurrent
# users do not run the model at the same time.
if __name__ == "__main__":
    demo.queue().launch()