import gradio as gr
import torch
import torch.nn.functional as F
from transformers import AutoTokenizer, AutoModelForMaskedLM
# ์„ค์ •
MODEL_ID = "solonsophy/kf-deberta-gen" # ํŒŒ์ธํŠœ๋‹๋œ ๋ชจ๋ธ
BASE_MODEL_ID = "kakaobank/kf-deberta-base" # ๊ธฐ๋ฐ˜ ๋ชจ๋ธ (ํ† ํฌ๋‚˜์ด์ €์šฉ)
MAX_LEN = 256
Q_MAX_LEN = 100
# ๋ชจ๋ธ ๋กœ๋“œ
print("๐Ÿ”„ Loading model...")
tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL_ID) # ๊ธฐ๋ฐ˜ ๋ชจ๋ธ์—์„œ ํ† ํฌ๋‚˜์ด์ € ๋กœ๋“œ
model = AutoModelForMaskedLM.from_pretrained(MODEL_ID) # ํŒŒ์ธํŠœ๋‹๋œ ๊ฐ€์ค‘์น˜ ๋กœ๋“œ
device = "cuda" if torch.cuda.is_available() else "cpu"
model = model.to(device)
model.eval()
print(f"โœ… Model loaded on {device}")
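
# Special-token ids used throughout the denoising loop below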
MASK_ID = tokenizer.mask_token_id
PAD_ID = tokenizer.pad_token_id
CLS_ID = tokenizer.cls_token_id
SEP_ID = tokenizer.sep_token_id
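
# The function below resembles MaskGIT-style iterative unmasking (this is a
# reading of the code as written, not a reference implementation): the answer
# region starts as all [MASK] tokens, every step predicts all masked positions
# in parallel, and only the highest-confidence samples are committed, so the
# answer is revealed gradually over num_steps rounds.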
def generate_response(question, num_steps, temperature, top_k, max_answer_len):
    """Diffusion-style answer generation via iterative unmasking."""
    if not question.strip():
        return "질문을 입력해주세요."
    # Slider values can arrive as floats; cast the integer-valued parameters
    # defensively so range()/topk() below receive ints
    num_steps = int(num_steps)
    top_k = int(top_k)
    max_answer_len = int(max_answer_len)
    # Tokenize the question
    q_tokens = tokenizer.encode(question, add_special_tokens=False)[:Q_MAX_LEN]
    # Initial sequence: [CLS] Q [SEP] [MASK]*N
    input_ids = [CLS_ID] + q_tokens + [SEP_ID] + [MASK_ID] * max_answer_len
    input_ids = input_ids[:MAX_LEN]
    answer_start = len(q_tokens) + 2  # skip [CLS] + question + [SEP]
    answer_end = len(input_ids)
    input_ids = torch.tensor([input_ids], device=device)
    attention_mask = torch.ones_like(input_ids)
    # Iterative denoising
    for step in range(num_steps):
        with torch.no_grad():
            outputs = model(input_ids=input_ids, attention_mask=attention_mask)
        logits = outputs.logits
        # Find the answer-region positions that are still masked
        mask_positions = (input_ids[0, answer_start:answer_end] == MASK_ID).nonzero(as_tuple=True)[0]
        mask_positions = mask_positions + answer_start
        if len(mask_positions) == 0:
            break
# ์ด๋ฒˆ ์Šคํ…์—์„œ unmaskํ•  ๊ฐœ์ˆ˜
remaining_steps = num_steps - step
tokens_per_step = max(1, len(mask_positions) // remaining_steps)
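        # With the UI defaults (80 masks, 50 steps) the first steps commit
        # 80 // 50 = 1 token each; since the mask count shrinks more slowly
        # than remaining_steps, the quota rises to 2 around step 20 and the
        # last mask is filled exactly on the final step.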
        # Temperature-scale the logits at the masked positions
        mask_logits = logits[0, mask_positions] / temperature
        # Top-k filtering: drop everything below the k-th best logit
        if top_k > 0:
            top_k_values, _ = torch.topk(mask_logits, min(top_k, mask_logits.size(-1)), dim=-1)
            threshold = top_k_values[:, -1].unsqueeze(-1)
            mask_logits = torch.where(mask_logits < threshold, float('-inf'), mask_logits)
# ์ƒ˜ํ”Œ๋ง
probs = F.softmax(mask_logits, dim=-1)
sampled_tokens = torch.multinomial(probs, num_samples=1).squeeze(-1)
# Confidence
confidences = probs.gather(1, sampled_tokens.unsqueeze(-1)).squeeze(-1)
# Confidence ๊ธฐ๋ฐ˜ unmask
_, top_indices = torch.topk(confidences, min(tokens_per_step, len(confidences)))
selected_positions = mask_positions[top_indices]
selected_tokens = sampled_tokens[top_indices]
input_ids[0, selected_positions] = selected_tokens
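        # Note: committed tokens are never re-masked, so an early
        # low-confidence commit cannot be revised later; this is likely one
        # source of the unstable output quality the UI warning mentions.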
# ๊ฒฐ๊ณผ ์ถ”์ถœ
answer_tokens = input_ids[0, answer_start:answer_end]
valid_mask = (answer_tokens != MASK_ID) & (answer_tokens != PAD_ID)
answer_tokens = answer_tokens[valid_mask]
answer = tokenizer.decode(answer_tokens, skip_special_tokens=True)
return answer.strip() if answer.strip() else "(์ƒ์„ฑ ์‹คํŒจ)"
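
# Direct call, bypassing the UI (prompt taken from the gr.Examples list below):
#   print(generate_response("안녕하세요", num_steps=50, temperature=0.5,
#                           top_k=10, max_answer_len=80))
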
# Gradio UI
with gr.Blocks(title="kf-deberta-gen Demo", theme=gr.themes.Soft()) as demo:
    gr.Markdown("""
    # 🌀 kf-deberta-gen Demo
    **Generative Diffusion BERT** - 한국어 Diffusion 기반 생성 언어 모델 (실험적)

    > ⚠️ 이 모델은 PoC 단계입니다. 생성 품질이 불안정하며 반복 생성 등의 문제가 있을 수 있습니다.
    """)
    with gr.Row():
        with gr.Column(scale=2):
            question_input = gr.Textbox(
                label="질문",
                placeholder="질문을 입력하세요...",
                lines=2
            )
            submit_btn = gr.Button("🚀 생성", variant="primary")
        with gr.Column(scale=1):
            num_steps = gr.Slider(10, 100, value=50, step=5, label="Steps")
            temperature = gr.Slider(0.1, 2.0, value=0.5, step=0.1, label="Temperature")
            top_k = gr.Slider(1, 50, value=10, step=1, label="Top-K")
            max_len = gr.Slider(20, 150, value=80, step=10, label="Max Answer Length")

    output = gr.Textbox(label="답변", lines=5)
    gr.Examples(
        examples=[
            ["오늘 날씨 어때?"],
            ["파이썬을 배우려면 어떻게 해야 하나요?"],
            ["안녕하세요"],
        ],
        inputs=question_input
    )

    submit_btn.click(
        fn=generate_response,
        inputs=[question_input, num_steps, temperature, top_k, max_len],
        outputs=output
    )
    question_input.submit(
        fn=generate_response,
        inputs=[question_input, num_steps, temperature, top_k, max_len],
        outputs=output
    )

if __name__ == "__main__":
    demo.launch()