import threading

import gradio as gr
import torch
import transformers
# --- Model configuration ---
# ID of the model to load from the Hugging Face Hub.
MODEL_ID = "Qwen/Qwen2.5-Coder-7B-Instruct"
# --- Model loading (runs once when the Space starts) ---
print("Loading the model... the first run may take a while.")
try:
    # 4-bit quantization cuts VRAM use enough to run on a single T4 GPU.
    # Passing load_in_4bit directly to from_pretrained is deprecated; the
    # supported route is a BitsAndBytesConfig (needs the bitsandbytes package).
    quantization_config = transformers.BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_compute_dtype=torch.float16,  # T4 (Turing) has no native bfloat16
    )
    model = transformers.AutoModelForCausalLM.from_pretrained(
        MODEL_ID,
        torch_dtype=torch.float16,   # dtype the T4 GPU supports natively
        device_map="auto",           # place the model on the GPU automatically
        quantization_config=quantization_config,
    )
    tokenizer = transformers.AutoTokenizer.from_pretrained(MODEL_ID)
    print("✅ Model loaded!")
except Exception as e:
    print(f"❌ Model loading failed: {e}")
    # On failure, respond() below falls back to a readable error message.
    model = None
    tokenizer = None
def respond(
    message,
    history: list[dict],
    system_message,
    max_tokens,
    temperature,
    top_p,
):
    """Stream a reply to the user's message."""
    if model is None:
        yield ("Failed to load the model. Check the Space's hardware settings "
               "and that the model name is correct.")
        return
    # Format the conversation for the Qwen chat template.
    messages = [{"role": "system", "content": system_message}]
    # With type="messages", Gradio's history is already a list of
    # {"role": ..., "content": ...} dicts, not (user, bot) tuples.
    for msg in history:
        messages.append({"role": msg["role"], "content": msg["content"]})
    messages.append({"role": "user", "content": message})
    # Render the messages with the tokenizer's chat template.
    prompt = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True,
    )
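    # For orientation only (an assumption about Qwen's template, not part of
    # the logic): Qwen2.5 uses ChatML-style tags, so `prompt` looks roughly like
    #   <|im_start|>system\n{system_message}<|im_end|>\n
    #   <|im_start|>user\n{message}<|im_end|>\n
    #   <|im_start|>assistant\n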
    # Stream the answer token by token. transformers pipelines have no
    # stream=True flag; the supported pattern is a TextIteratorStreamer
    # consumed here while model.generate runs in a background thread.
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    streamer = transformers.TextIteratorStreamer(
        tokenizer, skip_prompt=True, skip_special_tokens=True
    )
    generation_args = {
        **inputs,
        "max_new_tokens": max_tokens,
        "temperature": temperature,
        "top_p": top_p,
        "do_sample": True,
        "streamer": streamer,
    }
    thread = threading.Thread(target=model.generate, kwargs=generation_args)
    thread.start()
    response = ""
    for new_text in streamer:
        response += new_text
        yield response  # Gradio re-renders the partial reply on every yield
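# Quick smoke test (illustrative; run by hand in a Python shell rather than as
# part of the app -- the arguments mirror the Gradio inputs defined below):
#   final = ""
#   for partial in respond("Write hello world in C", [],
#                          "You are a helpful coding assistant.",
#                          64, 0.7, 0.95):
#       final = partial
#   print(final)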
"""
Gradio ChatInterface๋ฅผ ์‚ฌ์šฉํ•˜์—ฌ ์ฑ—๋ด‡ UI๋ฅผ ๋งŒ๋“ญ๋‹ˆ๋‹ค.
"""
chatbot = gr.ChatInterface(
    respond,
    type="messages",  # use the message-dict format (Gradio 4.x+)
    additional_inputs_accordion="⚙️ Parameter settings",
additional_inputs=[
gr.Textbox(
value="You are Qwen2.5-Coder, created by Alibaba Cloud. You are a helpful assistant specialized in coding and programming.",
label="System message"
),
gr.Slider(
minimum=1,
maximum=4096,
value=1024,
step=1,
label="Max new tokens"
),
gr.Slider(
minimum=0.1,
maximum=4.0,
value=0.7,
step=0.1,
label="Temperature"
),
gr.Slider(
minimum=0.1,
maximum=1.0,
value=0.95,
step=0.05,
label="Top-p (nucleus sampling)",
),
],
    examples=[
        ["Write a simple CNN model in PyTorch."],
        ["Optimize this Python code:\n\n```python\nfor i in range(len(my_list)):\n    print(my_list[i])\n```"],
        ["Create a FastAPI endpoint that returns 'hello world'."],
    ],
    cache_examples=False,  # disable example caching (saves memory)
)
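# Note: additional_inputs are passed to respond() positionally after
# (message, history), so their order above must match the signature:
# (system_message, max_tokens, temperature, top_p).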
# Compose the page layout with Gradio Blocks.
with gr.Blocks(theme=gr.themes.Soft(), title="My Own AI Code Reader") as demo:
    gr.Markdown("# 🤖 My Own AI Code Reader (Qwen2.5-Coder)")
    gr.Markdown("This chatbot generates and analyzes code with the **Qwen2.5-Coder-7B-Instruct** model.")
    chatbot.render()
if __name__ == "__main__":
demo.launch()
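    # If streamed partial replies do not show up, enabling the request queue
    # explicitly is a harmless tweak (it is on by default in recent Gradio):
    # demo.queue().launch()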