young476 commited on
Commit
d315a55
ยท
1 Parent(s): 92fe661

Add application file

Browse files
Files changed (1) hide show
  1. app.py +28 -13
app.py CHANGED
@@ -1,17 +1,32 @@
1
- import os
2
  from transformers import BertForSequenceClassification, BertTokenizer
 
3
 
4
- def save_model_and_tokenizer(model, tokenizer, save_dir='my_kobert_model'):
5
- if not os.path.exists(save_dir):
6
- os.makedirs(save_dir)
7
- model.save_pretrained(save_dir)
8
- tokenizer.save_pretrained(save_dir)
9
- print(f"๋ชจ๋ธ๊ณผ ํ† ํฌ๋‚˜์ด์ €๊ฐ€ '{save_dir}' ํด๋”์— ์ €์žฅ๋˜์—ˆ์Šต๋‹ˆ๋‹ค.")
10
 
11
- if __name__ == '__main__':
12
- # ์ด๋ฏธ ํ•™์Šต๋œ ๋ชจ๋ธ๊ณผ ํ† ํฌ๋‚˜์ด์ € ๊ฒฝ๋กœ๋กœ ์ˆ˜์ •ํ•˜์„ธ์š”
13
- model_name_or_path = 'bert-base-uncased' # ๋˜๋Š” ์ง์ ‘ ํ•™์Šตํ•œ ๋ชจ๋ธ ๊ฒฝ๋กœ
14
- model = BertForSequenceClassification.from_pretrained(model_name_or_path)
15
- tokenizer = BertTokenizer.from_pretrained(model_name_or_path)
16
 
17
- save_model_and_tokenizer(model, tokenizer, save_dir='my_kobert_model')
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
  from transformers import BertForSequenceClassification, BertTokenizer
3
+ import torch
4
 
5
+ # ๋ชจ๋ธ๊ณผ ํ† ํฌ๋‚˜์ด์ € ๋ถˆ๋Ÿฌ์˜ค๊ธฐ (๋ณธ์ธ Hugging Face ๋ชจ๋ธ ๊ฒฝ๋กœ๋กœ ์ˆ˜์ •)
6
+ MODEL_NAME = "young476/LyricToGenre0607"
7
+ model = BertForSequenceClassification.from_pretrained(MODEL_NAME)
8
+ tokenizer = BertTokenizer.from_pretrained(MODEL_NAME)
 
 
9
 
10
+ # ํด๋ž˜์Šค ์ด๋ฆ„ ๋ฆฌ์ŠคํŠธ (์˜ˆ์‹œ: ์‹ค์ œ ์žฅ๋ฅด๋ช…์œผ๋กœ ์ˆ˜์ •)
11
+ genre_labels = ["๋ฐœ๋ผ๋“œ", "๋Œ„์Šค", "ํž™ํ•ฉ", "๋ก", "ํŠธ๋กœํŠธ", "R&B"]
 
 
 
12
 
13
+ def predict_genre(lyrics):
14
+ inputs = tokenizer(lyrics, return_tensors="pt", truncation=True, max_length=256)
15
+ with torch.no_grad():
16
+ outputs = model(**inputs)
17
+ pred_id = outputs.logits.argmax(dim=1).item()
18
+ pred_label = genre_labels[pred_id]
19
+ probs = torch.softmax(outputs.logits, dim=1).squeeze().tolist()
20
+ prob_dict = {genre_labels[i]: float(probs[i]) for i in range(len(genre_labels))}
21
+ return pred_label, prob_dict
22
+
23
+ demo = gr.Interface(
24
+ fn=predict_genre,
25
+ inputs=gr.Textbox(lines=8, label="๊ฐ€์‚ฌ ์ž…๋ ฅ"),
26
+ outputs=[gr.Label(num_top_classes=1, label="์˜ˆ์ธก ์žฅ๋ฅด"), gr.Label(label="์žฅ๋ฅด๋ณ„ ํ™•๋ฅ ")],
27
+ title="๊ฐ€์‚ฌ ๊ธฐ๋ฐ˜ ์žฅ๋ฅด ๋ถ„๋ฅ˜๊ธฐ",
28
+ description="ํ•œ๊ตญ ๋…ธ๋ž˜ ๊ฐ€์‚ฌ๋ฅผ ์ž…๋ ฅํ•˜๋ฉด ์žฅ๋ฅด๋ฅผ ์˜ˆ์ธกํ•ฉ๋‹ˆ๋‹ค."
29
+ )
30
+
31
+ if __name__ == "__main__":
32
+ demo.launch()