|
|
import gradio as gr |
|
|
import torch |
|
|
from transformers import AutoModelForCausalLM, AutoTokenizer |
|
|
import time |
|
|
import numpy as np |
|
|
|
|
|
class MergedModelTester: |
|
|
def __init__(self): |
|
|
self.model = None |
|
|
self.tokenizer = None |
|
|
self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') |
|
|
|
|
|
def load_model(self, model_id="openfree/gpt2-bert", progress=gr.Progress()): |
|
|
"""๋ณํฉ ๋ชจ๋ธ ๋ก๋""" |
|
|
try: |
|
|
progress(0.2, desc="ํ ํฌ๋์ด์ ๋ก๋ ์ค...") |
|
|
self.tokenizer = AutoTokenizer.from_pretrained("gpt2") |
|
|
self.tokenizer.pad_token = self.tokenizer.eos_token |
|
|
|
|
|
progress(0.5, desc="๋ชจ๋ธ ๋ก๋ ์ค...") |
|
|
self.model = AutoModelForCausalLM.from_pretrained( |
|
|
model_id, |
|
|
torch_dtype=torch.float16 if self.device.type == 'cuda' else torch.float32, |
|
|
device_map="auto" if self.device.type == 'cuda' else None |
|
|
) |
|
|
|
|
|
if self.device.type == 'cpu': |
|
|
self.model = self.model.to(self.device) |
|
|
|
|
|
self.model.eval() |
|
|
|
|
|
progress(1.0, desc="์๋ฃ!") |
|
|
|
|
|
|
|
|
num_params = sum(p.numel() for p in self.model.parameters()) |
|
|
return f"""โ
๋ชจ๋ธ ๋ก๋ ์ฑ๊ณต! |
|
|
- ๋ชจ๋ธ: {model_id} |
|
|
- ํ๋ผ๋ฏธํฐ: {num_params:,} |
|
|
- ๋๋ฐ์ด์ค: {self.device}""" |
|
|
|
|
|
except Exception as e: |
|
|
return f"โ ๋ชจ๋ธ ๋ก๋ ์คํจ: {str(e)}" |
|
|
|
|
|
def generate_text(self, prompt, max_length=100, temperature=0.8, |
|
|
top_p=0.9, repetition_penalty=1.2, progress=gr.Progress()): |
|
|
"""ํ
์คํธ ์์ฑ""" |
|
|
if self.model is None: |
|
|
return "๋จผ์ ๋ชจ๋ธ์ ๋ก๋ํ์ธ์!", None, None |
|
|
|
|
|
try: |
|
|
progress(0.3, desc="ํ
์คํธ ์์ฑ ์ค...") |
|
|
|
|
|
|
|
|
inputs = self.tokenizer(prompt, return_tensors="pt", padding=True) |
|
|
inputs = {k: v.to(self.device) for k, v in inputs.items()} |
|
|
|
|
|
|
|
|
start_time = time.time() |
|
|
|
|
|
|
|
|
with torch.no_grad(): |
|
|
outputs = self.model.generate( |
|
|
**inputs, |
|
|
max_new_tokens=max_length, |
|
|
temperature=temperature, |
|
|
top_p=top_p, |
|
|
repetition_penalty=repetition_penalty, |
|
|
do_sample=True, |
|
|
pad_token_id=self.tokenizer.pad_token_id, |
|
|
eos_token_id=self.tokenizer.eos_token_id |
|
|
) |
|
|
|
|
|
|
|
|
generation_time = time.time() - start_time |
|
|
|
|
|
|
|
|
generated_text = self.tokenizer.decode(outputs[0], skip_special_tokens=True) |
|
|
|
|
|
|
|
|
input_tokens = len(inputs['input_ids'][0]) |
|
|
output_tokens = len(outputs[0]) |
|
|
new_tokens = output_tokens - input_tokens |
|
|
|
|
|
stats = f"""๐ ์์ฑ ํต๊ณ: |
|
|
- ์
๋ ฅ ํ ํฐ: {input_tokens} |
|
|
- ์์ฑ ํ ํฐ: {new_tokens} |
|
|
- ์ ์ฒด ํ ํฐ: {output_tokens} |
|
|
- ์์ฑ ์๊ฐ: {generation_time:.2f}์ด |
|
|
- ์๋: {new_tokens/generation_time:.1f} tokens/sec""" |
|
|
|
|
|
progress(1.0, desc="์๋ฃ!") |
|
|
|
|
|
return generated_text, stats, None |
|
|
|
|
|
except Exception as e: |
|
|
return f"โ ์์ฑ ์คํจ: {str(e)}", None, str(e) |
|
|
|
|
|
def compare_with_parents(self, prompt, max_length=50, progress=gr.Progress()): |
|
|
"""๋ถ๋ชจ ๋ชจ๋ธ๋ค๊ณผ ๋น๊ต""" |
|
|
results = {} |
|
|
|
|
|
|
|
|
try: |
|
|
progress(0.1, desc="GPT-2 ๋ก๋ ์ค...") |
|
|
gpt2_tokenizer = AutoTokenizer.from_pretrained("gpt2") |
|
|
gpt2_tokenizer.pad_token = gpt2_tokenizer.eos_token |
|
|
gpt2_model = AutoModelForCausalLM.from_pretrained("gpt2").to(self.device) |
|
|
|
|
|
progress(0.3, desc="GPT-2 ์์ฑ ์ค...") |
|
|
inputs = gpt2_tokenizer(prompt, return_tensors="pt").to(self.device) |
|
|
with torch.no_grad(): |
|
|
outputs = gpt2_model.generate(**inputs, max_new_tokens=max_length, do_sample=True) |
|
|
results['gpt2'] = gpt2_tokenizer.decode(outputs[0], skip_special_tokens=True) |
|
|
del gpt2_model |
|
|
|
|
|
except Exception as e: |
|
|
results['gpt2'] = f"๋ก๋ ์คํจ: {str(e)}" |
|
|
|
|
|
|
|
|
results['bert'] = "BERT๋ ์์ฑ ๋ชจ๋ธ์ด ์๋๋๋ค (์ธ์ฝ๋ ์ ์ฉ)" |
|
|
|
|
|
|
|
|
try: |
|
|
progress(0.6, desc="๋ณํฉ ๋ชจ๋ธ ์์ฑ ์ค...") |
|
|
if self.model is None: |
|
|
self.load_model() |
|
|
|
|
|
inputs = self.tokenizer(prompt, return_tensors="pt").to(self.device) |
|
|
with torch.no_grad(): |
|
|
outputs = self.model.generate(**inputs, max_new_tokens=max_length, do_sample=True) |
|
|
results['merged'] = self.tokenizer.decode(outputs[0], skip_special_tokens=True) |
|
|
|
|
|
except Exception as e: |
|
|
results['merged'] = f"์์ฑ ์คํจ: {str(e)}" |
|
|
|
|
|
progress(1.0, desc="์๋ฃ!") |
|
|
|
|
|
|
|
|
comparison = f"""๐ ๋ชจ๋ธ ๋น๊ต ๊ฒฐ๊ณผ: |
|
|
|
|
|
**GPT-2 (๋ถ๋ชจ 1):** |
|
|
{results['gpt2']} |
|
|
|
|
|
**BERT (๋ถ๋ชจ 2):** |
|
|
{results['bert']} |
|
|
|
|
|
**๋ณํฉ ๋ชจ๋ธ (openfree/gpt2-bert):** |
|
|
{results['merged']}""" |
|
|
|
|
|
return comparison |
|
|
|
|
|
|
|
|
# Single tester instance shared by every event handler below.
tester = MergedModelTester()


# NOTE(review): the Korean UI strings were recovered from a mis-encoded copy of this
# file; a few emoji whose bytes were unrecoverable were re-chosen — confirm against
# the original before shipping.
with gr.Blocks(title="GPT2-BERT 병합 모델 테스터") as demo:
    gr.Markdown("""
    # 🧬 GPT2-BERT 병합 모델 테스터

    진화적 알고리즘으로 병합된 [openfree/gpt2-bert](https://huggingface.co/openfree/gpt2-bert) 모델을 테스트합니다.

    ## 📋 모델 정보
    - **부모 1**: openai-community/gpt2
    - **부모 2**: google-bert/bert-base-uncased
    - **병합 방법**: SLERP (진화적 최적화)
    - **최종 성능**: 82-84% accuracy
    """)

    # Quick test: load the model, then sample with adjustable parameters.
    with gr.Tab("🚀 빠른 테스트"):
        with gr.Row():
            with gr.Column():
                load_btn = gr.Button("📥 모델 로드", variant="primary")
                load_status = gr.Textbox(label="로드 상태", lines=4)

                prompt_input = gr.Textbox(
                    label="프롬프트",
                    placeholder="텍스트를 입력하세요...",
                    value="The future of AI is",
                    lines=3,
                )

                with gr.Row():
                    max_length = gr.Slider(20, 200, 100, label="최대 길이")
                    temperature = gr.Slider(0.1, 2.0, 0.8, label="Temperature")

                with gr.Row():
                    top_p = gr.Slider(0.1, 1.0, 0.9, label="Top-p")
                    rep_penalty = gr.Slider(1.0, 2.0, 1.2, label="반복 패널티")

                generate_btn = gr.Button("✨ 텍스트 생성", variant="primary")

            with gr.Column():
                output_text = gr.Textbox(label="생성된 텍스트", lines=10)
                stats_text = gr.Textbox(label="생성 통계", lines=6)

    # Side-by-side comparison with the parent models.
    with gr.Tab("🔬 모델 비교"):
        compare_prompt = gr.Textbox(
            label="비교할 프롬프트",
            value="Once upon a time",
            lines=2,
        )
        compare_length = gr.Slider(20, 100, 50, label="생성 길이")
        compare_btn = gr.Button("🔄 부모 모델과 비교", variant="primary")
        comparison_output = gr.Textbox(label="비교 결과", lines=20)

    # Task-specific example prompts.
    with gr.Tab("🧪 고급 테스트"):
        gr.Markdown("### 다양한 태스크 테스트")

        task_prompts = {
            "이야기 생성": "In a distant galaxy, a young explorer discovered",
            "질문 답변": "Q: What is machine learning?\nA:",
            "코드 생성": "# Python function to calculate fibonacci\ndef fibonacci(n):",
            "시 생성": "Roses are red,\nViolets are blue,",
        }
        default_task = "이야기 생성"

        task_type = gr.Radio(
            list(task_prompts),
            label="태스크 선택",
            value=default_task,
        )

        def update_prompt(task):
            """Return the example prompt for ``task`` (empty string if unknown)."""
            return task_prompts.get(task, "")

        # Pre-fill with the default task's prompt. Otherwise the box stays empty until
        # the radio changes, and running the task would generate from "".
        task_prompt = gr.Textbox(
            label="태스크 프롬프트",
            lines=3,
            value=update_prompt(default_task),
        )
        task_output = gr.Textbox(label="결과", lines=10)
        task_btn = gr.Button("🎯 태스크 실행", variant="primary")

        task_type.change(update_prompt, task_type, task_prompt)

    # Static experiment summary plus a smoke-test runner.
    with gr.Tab("📈 성능 분석"):
        gr.Markdown("""
        ### 진화 실험 결과

        | 메트릭 | 값 |
        |--------|-----|
        | 초기 성능 | 10.56% |
        | 최종 성능 | 82-84% |
        | 개선율 | +700% |
        | 총 개선 횟수 | 2,136회 |
        | 학습 시간 | 7.7분 |

        ### 세대별 성능
        - **초기 (0-2000)**: 큰 개선 (+20-30%/세대)
        - **중기 (2000-5000)**: 중간 개선 (+10-15%/세대)
        - **후기 (5000-10000)**: 미세 조정 (+2-5%/세대)
        """)

        test_suite_btn = gr.Button("🚀 전체 테스트 스위트 실행", variant="primary")
        test_results = gr.Textbox(label="테스트 결과", lines=15)

    # Event handlers. Each declares its own ``progress=gr.Progress()`` parameter,
    # which is how Gradio decides to inject a tracked progress bar; a bare lambda
    # wrapper hides that parameter.
    def load_default_model(progress=gr.Progress()):
        """Load the default merged model and return the status text."""
        return tester.load_model("openfree/gpt2-bert", progress=progress)

    def run_task(prompt, progress=gr.Progress()):
        """Run a task prompt with fixed sampling settings and return only the text."""
        text, _stats, _error = tester.generate_text(
            prompt, 100, 0.8, 0.9, 1.2, progress=progress
        )
        return text

    def run_test_suite(progress=gr.Progress()):
        """Run a fixed set of short prompts and report success or failure per prompt.

        Returns:
            One status entry per prompt, joined with newlines.
        """
        test_prompts = [
            "The meaning of life is",
            "import numpy as np\n",
            "Scientists have discovered",
            "def hello_world():",
            "Breaking news:",
        ]
        total = len(test_prompts)
        results = []

        for i, prompt in enumerate(test_prompts, start=1):
            progress(i / total, desc=f"테스트 {i}/{total}")
            try:
                # A no-op callback keeps per-prompt updates from competing with this
                # suite-level progress bar.
                _text, stats, error = tester.generate_text(
                    prompt, 30, progress=lambda *args, **kwargs: None
                )
                # generate_text reports failures through its return value instead of
                # raising. stats is None both when no model is loaded and when
                # generation fails, so checking it is what detects a failure.
                ok = stats is not None and error is None
            except Exception:
                ok = False

            if ok:
                results.append(f"✅ 프롬프트: {prompt[:30]}...\n 생성 성공")
            else:
                results.append(f"❌ 프롬프트: {prompt[:30]}...\n 생성 실패")

        return "\n".join(results)

    load_btn.click(load_default_model, outputs=load_status)

    generate_btn.click(
        tester.generate_text,
        inputs=[prompt_input, max_length, temperature, top_p, rep_penalty],
        # The third return value (error text) goes to a hidden sink.
        outputs=[output_text, stats_text, gr.Textbox(visible=False)],
    )

    compare_btn.click(
        tester.compare_with_parents,
        inputs=[compare_prompt, compare_length],
        outputs=comparison_output,
    )

    task_btn.click(run_task, inputs=task_prompt, outputs=task_output)

    test_suite_btn.click(run_test_suite, outputs=test_results)


if __name__ == "__main__":
    # Serve locally without creating a public share link.
    demo.launch(share=False)