Spaces:
Running
Running
File size: 5,139 Bytes
19d03cc 309d106 19d03cc 309d106 19d03cc |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 |
import gradio as gr
import torch
import torch.nn.functional as F
from transformers import AutoTokenizer, AutoModelForMaskedLM
# Configuration
MODEL_ID = "solonsophy/kf-deberta-gen"  # fine-tuned model (weights)
BASE_MODEL_ID = "kakaobank/kf-deberta-base"  # base model (tokenizer source)
MAX_LEN = 256  # cap on the full prompt length: [CLS] + question + [SEP] + answer slots
Q_MAX_LEN = 100  # question tokens kept; longer questions are truncated

# Model loading (runs once at import time)
print("Loading model...")
tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL_ID)  # tokenizer comes from the base model
model = AutoModelForMaskedLM.from_pretrained(MODEL_ID)  # fine-tuned masked-LM weights
device = "cuda" if torch.cuda.is_available() else "cpu"
model = model.to(device)
model.eval()  # inference only
print(f"✅ Model loaded on {device}")

# Special-token ids used to build the [CLS] Q [SEP] [MASK]* prompt and to
# filter the decoded answer.
MASK_ID = tokenizer.mask_token_id
PAD_ID = tokenizer.pad_token_id
CLS_ID = tokenizer.cls_token_id
SEP_ID = tokenizer.sep_token_id
def generate_response(question, num_steps, temperature, top_k, max_answer_len):
    """Generate an answer to *question* by iterative masked-token denoising.

    The prompt is laid out as ``[CLS] question [SEP] [MASK] * max_answer_len``
    and truncated to ``MAX_LEN`` tokens. Each step:

    1. Runs the masked LM over the whole sequence.
    2. Samples a token for every still-masked answer slot, using temperature
       and optional top-k filtering.
    3. Commits only the most confident samples.

    The number committed per step spreads the remaining masks over the
    remaining steps, so the final step fills whatever is left.

    Args:
        question: User question text (``None`` or blank is rejected).
        num_steps: Number of denoising iterations (coerced to ``int``).
        temperature: Logit temperature; values <= 0 are clamped to a small
            positive number to avoid division by zero.
        top_k: Keep only the k highest logits per slot; ``<= 0`` disables it
            (coerced to ``int``).
        max_answer_len: Number of ``[MASK]`` answer slots (coerced to ``int``).

    Returns:
        The decoded answer. If the question is empty, a Korean prompt asking
        for input is returned instead. If nothing decodes, a Korean
        "generation failed" marker is returned.
    """
    if not question or not question.strip():
        return "질문을 입력해주세요."

    # UI sliders may hand over floats; range(), list repetition and
    # torch.topk all require ints.
    num_steps = int(num_steps)
    top_k = int(top_k)
    max_answer_len = int(max_answer_len)
    temperature = max(float(temperature), 1e-5)

    # Tokenize the question; special tokens are inserted manually below.
    q_tokens = tokenizer.encode(question, add_special_tokens=False)[:Q_MAX_LEN]

    # Initial sequence: [CLS] Q [SEP] [MASK]*N, truncated to MAX_LEN.
    ids = ([CLS_ID] + q_tokens + [SEP_ID] + [MASK_ID] * max_answer_len)[:MAX_LEN]
    answer_start = len(q_tokens) + 2  # first slot after [CLS] Q [SEP]
    answer_end = len(ids)

    input_ids = torch.tensor([ids], device=device)
    attention_mask = torch.ones_like(input_ids)

    # Iterative denoising (no gradients needed anywhere in the loop).
    with torch.no_grad():
        for step in range(num_steps):
            logits = model(input_ids=input_ids, attention_mask=attention_mask).logits

            # Absolute positions of answer slots that are still [MASK].
            answer_slice = input_ids[0, answer_start:answer_end]
            mask_positions = (answer_slice == MASK_ID).nonzero(as_tuple=True)[0] + answer_start
            if mask_positions.numel() == 0:
                break

            # Spread remaining masks over remaining steps; when one step is
            # left, every remaining mask is filled.
            remaining_steps = num_steps - step
            tokens_per_step = max(1, mask_positions.numel() // remaining_steps)

            mask_logits = logits[0, mask_positions] / temperature
            # Never sample [MASK] itself: such a slot would stay masked and be
            # silently dropped from the final answer.
            mask_logits[:, MASK_ID] = float("-inf")

            # Top-k filtering: drop every logit below the k-th largest per slot.
            if top_k > 0:
                k = min(top_k, mask_logits.size(-1))
                threshold = torch.topk(mask_logits, k, dim=-1).values[:, -1:]
                mask_logits = mask_logits.masked_fill(mask_logits < threshold, float("-inf"))

            probs = F.softmax(mask_logits, dim=-1)
            sampled_tokens = torch.multinomial(probs, num_samples=1).squeeze(-1)
            # Confidence = probability assigned to the token actually sampled.
            confidences = probs.gather(1, sampled_tokens.unsqueeze(-1)).squeeze(-1)

            # Commit only the most confident samples this step.
            n_commit = min(tokens_per_step, confidences.numel())
            top_indices = torch.topk(confidences, n_commit).indices
            input_ids[0, mask_positions[top_indices]] = sampled_tokens[top_indices]

    # Decode the answer region, dropping any leftover [MASK] / [PAD] slots.
    answer_tokens = input_ids[0, answer_start:answer_end]
    keep = (answer_tokens != MASK_ID) & (answer_tokens != PAD_ID)
    answer = tokenizer.decode(answer_tokens[keep].tolist(), skip_special_tokens=True).strip()
    return answer or "(생성 실패)"
# Gradio UI: question box + sampling controls wired to generate_response.
with gr.Blocks(title="kf-deberta-gen Demo", theme=gr.themes.Soft()) as demo:
    gr.Markdown(
        """
# kf-deberta-gen Demo

**Generative Diffusion BERT** - 한국어 Diffusion 기반 생성 언어 모델 (실험적)

> ⚠️ 이 모델은 PoC 단계입니다. 생성 품질이 불안정하며 반복 생성 등의 문제가 있을 수 있습니다.
"""
    )

    with gr.Row():
        with gr.Column(scale=2):
            question_input = gr.Textbox(
                label="질문",
                placeholder="질문을 입력하세요...",
                lines=2,
            )
            submit_btn = gr.Button("생성", variant="primary")
        with gr.Column(scale=1):
            num_steps = gr.Slider(10, 100, value=50, step=5, label="Steps")
            temperature = gr.Slider(0.1, 2.0, value=0.5, step=0.1, label="Temperature")
            top_k = gr.Slider(1, 50, value=10, step=1, label="Top-K")
            max_len = gr.Slider(20, 150, value=80, step=10, label="Max Answer Length")

    output = gr.Textbox(label="답변", lines=5)

    gr.Examples(
        examples=[
            ["오늘 날씨 어때?"],
            ["파이썬을 배우려면 어떻게 해야 하나요?"],
            ["안녕하세요"],
        ],
        inputs=question_input,
    )

    # Same handler for the button click and for pressing Enter in the textbox;
    # order must match generate_response's positional parameters.
    gen_inputs = [question_input, num_steps, temperature, top_k, max_len]
    for trigger in (submit_btn.click, question_input.submit):
        trigger(fn=generate_response, inputs=gen_inputs, outputs=output)
def _launch():
    """Start the Gradio server for the module-level ``demo`` app."""
    demo.launch()


if __name__ == "__main__":
    _launch()
|