young476 committed on
Commit
b8356c0
·
1 Parent(s): 7192baa

Edit app.py

Browse files
Files changed (1) hide show
  1. app.py +58 -13
app.py CHANGED
@@ -2,27 +2,72 @@ import gradio as gr
2
  from transformers import AutoTokenizer, AutoModelForSequenceClassification, TextClassificationPipeline
3
  import torch
4
 
5
- MODEL_PATH = "./"
6
 
7
- tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)
8
- model = AutoModelForSequenceClassification.from_pretrained(MODEL_PATH)
9
- pipeline = TextClassificationPipeline(model=model, tokenizer=tokenizer, device=0 if torch.cuda.is_available() else -1)
10
 
11
- def predict(lyrics):
12
- result = pipeline(lyrics)
13
- # ์˜ˆ์˜๊ฒŒ ๊ฒฐ๊ณผ๋งŒ ์ถ”์ถœ
14
- if isinstance(result, list) and len(result) > 0:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
15
  label = result[0].get("label", "Unknown")
16
- score = result[0].get("score", 0)
 
17
  return f"{label} ({score:.2f})"
18
- return "No result"
 
 
19
 
 
20
  demo = gr.Interface(
21
  fn=predict,
22
- inputs=gr.Textbox(label="Lyrics"),
23
  outputs=gr.Textbox(label="Predicted Genre"),
24
- title="Lyrics Genre Predictor"
 
 
25
  )
26
 
27
  if __name__ == "__main__":
28
- demo.launch()
 
 
 
 
 
 
 
 
 
2
  from transformers import AutoTokenizer, AutoModelForSequenceClassification, TextClassificationPipeline
3
  import torch
4
 
5
+ print("Gradio app script started.") # ์Šคํฌ๋ฆฝํŠธ ์‹œ์ž‘ ๋กœ๊ทธ
6
 
7
+ MODEL_PATH = "./" # ๋ชจ๋ธ ํŒŒ์ผ์ด ์ปจํ…Œ์ด๋„ˆ ๋‚ด app.py์™€ ๋™์ผํ•œ ๋””๋ ‰ํ† ๋ฆฌ์— ์žˆ๋‹ค๊ณ  ๊ฐ€์ •
 
 
8
 
9
+ pipeline_instance = None # ํŒŒ์ดํ”„๋ผ์ธ ์ธ์Šคํ„ด์Šค๋ฅผ ์ „์—ญ์ ์œผ๋กœ ์„ ์–ธ
10
+
11
+ try:
12
+ print(f"Loading tokenizer from: {MODEL_PATH}")
13
+ tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)
14
+ print(f"Loading model from: {MODEL_PATH}")
15
+ model = AutoModelForSequenceClassification.from_pretrained(MODEL_PATH)
16
+ device_to_use = 0 if torch.cuda.is_available() else -1 # GPU ์šฐ์„  ์‚ฌ์šฉ, ์—†์œผ๋ฉด CPU
17
+ print(f"Using device: {'cuda:0' if device_to_use == 0 else 'cpu'}")
18
+ pipeline_instance = TextClassificationPipeline(model=model, tokenizer=tokenizer, device=device_to_use)
19
+ print("Pipeline created successfully.")
20
+ except Exception as e:
21
+ print(f"Error loading model, tokenizer, or creating pipeline: {e}")
22
+ # ํŒŒ์ดํ”„๋ผ์ธ ์ƒ์„ฑ ์‹คํŒจ ์‹œ ์•ฑ์ด ๊ณ„์† ์‹คํ–‰๋˜๋„๋ก ํ•˜์ง€๋งŒ, predict ํ•จ์ˆ˜์—์„œ ์ฒ˜๋ฆฌ
23
+
24
+ def predict(lyrics: str): # ์ž…๋ ฅ ํƒ€์ž… ๋ช…์‹œ (Python 3.9+ ์—์„œ ๊ถŒ์žฅ)
25
+ print(f"--- PREDICT FUNCTION CALLED ---")
26
+ print(f"Received lyrics: '{lyrics}' (Type: {type(lyrics)})")
27
+
28
+ if pipeline_instance is None:
29
+ print("Pipeline is not initialized. Cannot predict.")
30
+ return "์˜ค๋ฅ˜: ๋ชจ๋ธ ํŒŒ์ดํ”„๋ผ์ธ์ด ์ดˆ๊ธฐํ™”๋˜์ง€ ์•Š์•˜์Šต๋‹ˆ๋‹ค. (0.00)"
31
+
32
+ if not lyrics or not isinstance(lyrics, str) or lyrics.strip() == "":
33
+ print("Lyrics are empty, not a string, or whitespace only.")
34
+ return "์ž…๋ ฅ ๊ฐ€์‚ฌ๊ฐ€ ๋น„์–ด์žˆ๊ฑฐ๋‚˜ ์œ ํšจํ•˜์ง€ ์•Š์Šต๋‹ˆ๋‹ค. (0.00)"
35
+
36
+ try:
37
+ print("Performing prediction with pipeline...")
38
+ result = pipeline_instance(lyrics) # ์ „์—ญ pipeline_instance ์‚ฌ์šฉ
39
+ print(f"Pipeline raw result: {result}")
40
+ except Exception as e:
41
+ print(f"Error during pipeline prediction: {e}")
42
+ return f"์˜ˆ์ธก ์ค‘ ์˜ค๋ฅ˜ ๋ฐœ์ƒ: {str(e)} (0.00)" # ์˜ค๋ฅ˜ ๋ฉ”์‹œ์ง€ ํฌํ•จ
43
+
44
+ # ๊ฒฐ๊ณผ ์ถ”์ถœ ๋กœ์ง (์ด์ „๊ณผ ๋™์ผ)
45
+ if isinstance(result, list) and len(result) > 0 and isinstance(result[0], dict):
46
  label = result[0].get("label", "Unknown")
47
+ score = result[0].get("score", 0.0) # ์ ์ˆ˜๊ฐ€ float์ธ์ง€ ํ™•์ธ
48
+ print(f"Extracted prediction: Label='{label}', Score={score}")
49
  return f"{label} ({score:.2f})"
50
+ else:
51
+ print(f"Pipeline returned no result or unexpected format: {result}")
52
+ return "๊ฒฐ๊ณผ๋ฅผ ์ถ”์ถœํ•  ์ˆ˜ ์—†๊ฑฐ๋‚˜ ํ˜•์‹์ด ์˜ฌ๋ฐ”๋ฅด์ง€ ์•Š์Šต๋‹ˆ๋‹ค. (0.00)"
53
 
54
+ print("Creating Gradio Interface...")
55
  demo = gr.Interface(
56
  fn=predict,
57
+ inputs=gr.Textbox(label="Lyrics", placeholder="์—ฌ๊ธฐ์— ๊ฐ€์‚ฌ๋ฅผ ์ž…๋ ฅํ•˜์„ธ์š”..."),
58
  outputs=gr.Textbox(label="Predicted Genre"),
59
+ title="Lyrics Genre Predictor (Local Docker)",
60
+ description="๊ฐ€์‚ฌ๋ฅผ ์ž…๋ ฅํ•˜๋ฉด ๋กœ์ปฌ Docker ์ปจํ…Œ์ด๋„ˆ์—์„œ ์‹คํ–‰ ์ค‘์ธ ๋ชจ๋ธ์ด ์žฅ๋ฅด๋ฅผ ์˜ˆ์ธกํ•ฉ๋‹ˆ๋‹ค."
61
+ # api_name="predict" # ๋ช…์‹œ์ ์œผ๋กœ API ์ด๋ฆ„์„ ์„ค์ •ํ•˜๋ฉด /api/predict ์—”๋“œํฌ์ธํŠธ๊ฐ€ ํ™•์‹คํžˆ ์ƒ์„ฑ๋ฉ๋‹ˆ๋‹ค.
62
  )
63
 
64
  if __name__ == "__main__":
65
+ print("Launching Gradio app...")
66
+ # API ์ ‘๊ทผ์„ ์œ„ํ•ด์„œ๋Š” queue()๋ฅผ ์‚ฌ์šฉํ•˜๊ณ  server_name์„ ์„ค์ •ํ•˜๋Š” ๊ฒƒ์ด ์ข‹์Šต๋‹ˆ๋‹ค.
67
+ # server_name="0.0.0.0"์€ Docker ์ปจํ…Œ์ด๋„ˆ ์™ธ๋ถ€(์˜ˆ: ํ˜ธ์ŠคํŠธ์˜ Next.js)์—์„œ ์ ‘๊ทผ ํ—ˆ์šฉ
68
+ # server_port=7860์€ Gradio ๊ธฐ๋ณธ ํฌํŠธ
69
+ # queue() ์‚ฌ์šฉ ์‹œ api_open=True๊ฐ€ ๊ธฐ๋ณธ๊ฐ’์ธ ๊ฒฝ์šฐ๊ฐ€ ๋งŽ์•„ /api/predict ์—”๋“œํฌ์ธํŠธ๊ฐ€ ํ™œ์„ฑํ™”๋  ๊ฐ€๋Šฅ์„ฑ์ด ๋†’์Œ
70
+ demo.queue().launch(server_name="0.0.0.0", server_port=7860)
71
+ # ๋˜๋Š”, queue() ์—†์ด api_name์„ ๋ช…์‹œ:
72
+ # demo.launch(server_name="0.0.0.0", server_port=7860, api_name="predict")
73
+ print(f"Gradio app launched. Access UI at http://localhost:7860. API (likely) at http://localhost:7860/api/predict")