Add application file
Browse files
app.py
CHANGED
|
@@ -1,17 +1,32 @@
|
|
| 1 |
-
import
|
| 2 |
from transformers import BertForSequenceClassification, BertTokenizer
|
|
|
|
| 3 |
|
| 4 |
-
|
| 5 |
-
|
| 6 |
-
|
| 7 |
-
|
| 8 |
-
tokenizer.save_pretrained(save_dir)
|
| 9 |
-
print(f"๋ชจ๋ธ๊ณผ ํ ํฌ๋์ด์ ๊ฐ '{save_dir}' ํด๋์ ์ ์ฅ๋์์ต๋๋ค.")
|
| 10 |
|
| 11 |
-
|
| 12 |
-
|
| 13 |
-
model_name_or_path = 'bert-base-uncased' # ๋๋ ์ง์ ํ์ตํ ๋ชจ๋ธ ๊ฒฝ๋ก
|
| 14 |
-
model = BertForSequenceClassification.from_pretrained(model_name_or_path)
|
| 15 |
-
tokenizer = BertTokenizer.from_pretrained(model_name_or_path)
|
| 16 |
|
| 17 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import gradio as gr
|
| 2 |
from transformers import BertForSequenceClassification, BertTokenizer
|
| 3 |
+
import torch
|
| 4 |
|
| 5 |
+
# ๋ชจ๋ธ๊ณผ ํ ํฌ๋์ด์ ๋ถ๋ฌ์ค๊ธฐ (๋ณธ์ธ Hugging Face ๋ชจ๋ธ ๊ฒฝ๋ก๋ก ์์ )
|
| 6 |
+
MODEL_NAME = "young476/LyricToGenre0607"
|
| 7 |
+
model = BertForSequenceClassification.from_pretrained(MODEL_NAME)
|
| 8 |
+
tokenizer = BertTokenizer.from_pretrained(MODEL_NAME)
|
|
|
|
|
|
|
| 9 |
|
| 10 |
+
# ํด๋์ค ์ด๋ฆ ๋ฆฌ์คํธ (์์: ์ค์ ์ฅ๋ฅด๋ช
์ผ๋ก ์์ )
|
| 11 |
+
genre_labels = ["๋ฐ๋ผ๋", "๋์ค", "ํํฉ", "๋ก", "ํธ๋กํธ", "R&B"]
|
|
|
|
|
|
|
|
|
|
| 12 |
|
| 13 |
+
def predict_genre(lyrics):
|
| 14 |
+
inputs = tokenizer(lyrics, return_tensors="pt", truncation=True, max_length=256)
|
| 15 |
+
with torch.no_grad():
|
| 16 |
+
outputs = model(**inputs)
|
| 17 |
+
pred_id = outputs.logits.argmax(dim=1).item()
|
| 18 |
+
pred_label = genre_labels[pred_id]
|
| 19 |
+
probs = torch.softmax(outputs.logits, dim=1).squeeze().tolist()
|
| 20 |
+
prob_dict = {genre_labels[i]: float(probs[i]) for i in range(len(genre_labels))}
|
| 21 |
+
return pred_label, prob_dict
|
| 22 |
+
|
| 23 |
+
demo = gr.Interface(
|
| 24 |
+
fn=predict_genre,
|
| 25 |
+
inputs=gr.Textbox(lines=8, label="๊ฐ์ฌ ์
๋ ฅ"),
|
| 26 |
+
outputs=[gr.Label(num_top_classes=1, label="์์ธก ์ฅ๋ฅด"), gr.Label(label="์ฅ๋ฅด๋ณ ํ๋ฅ ")],
|
| 27 |
+
title="๊ฐ์ฌ ๊ธฐ๋ฐ ์ฅ๋ฅด ๋ถ๋ฅ๊ธฐ",
|
| 28 |
+
description="ํ๊ตญ ๋
ธ๋ ๊ฐ์ฌ๋ฅผ ์
๋ ฅํ๋ฉด ์ฅ๋ฅด๋ฅผ ์์ธกํฉ๋๋ค."
|
| 29 |
+
)
|
| 30 |
+
|
| 31 |
+
if __name__ == "__main__":
|
| 32 |
+
demo.launch()
|