hiddenFront commited on
Commit
ec61894
ยท
verified ยท
1 Parent(s): e66afc2

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +15 -8
app.py CHANGED
@@ -1,11 +1,11 @@
1
  from fastapi import FastAPI, Request
2
- from transformers import AutoTokenizer
3
  from huggingface_hub import hf_hub_download
4
  import torch
5
  import pickle
6
  import os
7
- import psutil
8
  import sys
 
9
 
10
  app = FastAPI()
11
  device = torch.device("cpu")
@@ -16,22 +16,27 @@ try:
16
  category = pickle.load(f)
17
  print("โœ… category.pkl ๋กœ๋“œ ์„ฑ๊ณต.")
18
  except FileNotFoundError:
19
- print("โŒ Error: category.pkl ํŒŒ์ผ์„ ์ฐพ์„ ์ˆ˜ ์—†์Šต๋‹ˆ๋‹ค. ํ”„๋กœ์ ํŠธ ๋ฃจํŠธ์— ์žˆ๋Š”์ง€ ํ™•์ธํ•˜์„ธ์š”.")
20
  sys.exit(1)
21
 
22
  # ํ† ํฌ๋‚˜์ด์ € ๋กœ๋“œ
23
  tokenizer = AutoTokenizer.from_pretrained("skt/kobert-base-v1")
24
  print("โœ… ํ† ํฌ๋‚˜์ด์ € ๋กœ๋“œ ์„ฑ๊ณต.")
25
 
 
 
 
 
 
26
  HF_MODEL_REPO_ID = "hiddenFront/TextClassifier"
27
  HF_MODEL_FILENAME = "textClassifierModel.pt"
28
 
29
- # ๋ฉ”๋ชจ๋ฆฌ ํ™•์ธ
30
  process = psutil.Process(os.getpid())
31
  mem_before = process.memory_info().rss / (1024 * 1024)
32
  print(f"๐Ÿ“ฆ ๋ชจ๋ธ ๋‹ค์šด๋กœ๋“œ ์ „ ๋ฉ”๋ชจ๋ฆฌ ์‚ฌ์šฉ๋Ÿ‰: {mem_before:.2f} MB")
33
 
34
- # ๋ชจ๋ธ ๋‹ค์šด๋กœ๋“œ ๋ฐ ๋กœ๋“œ
35
  try:
36
  model_path = hf_hub_download(repo_id=HF_MODEL_REPO_ID, filename=HF_MODEL_FILENAME)
37
  print(f"โœ… ๋ชจ๋ธ ํŒŒ์ผ ๋‹ค์šด๋กœ๋“œ ์„ฑ๊ณต: {model_path}")
@@ -39,14 +44,16 @@ try:
39
  mem_after_dl = process.memory_info().rss / (1024 * 1024)
40
  print(f"๐Ÿ“ฆ ๋ชจ๋ธ ๋‹ค์šด๋กœ๋“œ ํ›„ ๋ฉ”๋ชจ๋ฆฌ ์‚ฌ์šฉ๋Ÿ‰: {mem_after_dl:.2f} MB")
41
 
42
- model = torch.load(model_path, map_location=device) # ์ „์ฒด ๋ชจ๋ธ ๊ฐ์ฒด ๋กœ๋“œ
 
 
43
  model.eval()
44
 
45
  mem_after_load = process.memory_info().rss / (1024 * 1024)
46
  print(f"๐Ÿ“ฆ ๋ชจ๋ธ ๋กœ๋“œ ํ›„ ๋ฉ”๋ชจ๋ฆฌ ์‚ฌ์šฉ๋Ÿ‰: {mem_after_load:.2f} MB")
47
- print("โœ… ๋ชจ๋ธ ๋กœ๋“œ ์„ฑ๊ณต")
48
  except Exception as e:
49
- print(f"โŒ Error: ๋ชจ๋ธ ๋‹ค์šด๋กœ๋“œ ๋˜๋Š” ๋กœ๋“œ ์ค‘ ์˜ค๋ฅ˜ ๋ฐœ์ƒ: {e}")
50
  sys.exit(1)
51
 
52
  # ์˜ˆ์ธก API
 
1
  from fastapi import FastAPI, Request
2
+ from transformers import BertForSequenceClassification, AutoTokenizer
3
  from huggingface_hub import hf_hub_download
4
  import torch
5
  import pickle
6
  import os
 
7
  import sys
8
+ import psutil
9
 
10
  app = FastAPI()
11
  device = torch.device("cpu")
 
16
  category = pickle.load(f)
17
  print("โœ… category.pkl ๋กœ๋“œ ์„ฑ๊ณต.")
18
  except FileNotFoundError:
19
+ print("โŒ Error: category.pkl ํŒŒ์ผ์„ ์ฐพ์„ ์ˆ˜ ์—†์Šต๋‹ˆ๋‹ค.")
20
  sys.exit(1)
21
 
22
  # ํ† ํฌ๋‚˜์ด์ € ๋กœ๋“œ
23
  tokenizer = AutoTokenizer.from_pretrained("skt/kobert-base-v1")
24
  print("โœ… ํ† ํฌ๋‚˜์ด์ € ๋กœ๋“œ ์„ฑ๊ณต.")
25
 
26
+ # ๋ชจ๋ธ ๊ตฌ์กฐ ์žฌ์ •์˜
27
+ num_labels = len(category) # ๋ถ„๋ฅ˜ํ•  ํด๋ž˜์Šค ์ˆ˜์— ๋”ฐ๋ผ
28
+ model = BertForSequenceClassification.from_pretrained("skt/kobert-base-v1", num_labels=num_labels)
29
+ model.to(device)
30
+
31
  HF_MODEL_REPO_ID = "hiddenFront/TextClassifier"
32
  HF_MODEL_FILENAME = "textClassifierModel.pt"
33
 
34
+ # ๋ฉ”๋ชจ๋ฆฌ ์ธก์ • ์ „
35
  process = psutil.Process(os.getpid())
36
  mem_before = process.memory_info().rss / (1024 * 1024)
37
  print(f"๐Ÿ“ฆ ๋ชจ๋ธ ๋‹ค์šด๋กœ๋“œ ์ „ ๋ฉ”๋ชจ๋ฆฌ ์‚ฌ์šฉ๋Ÿ‰: {mem_before:.2f} MB")
38
 
39
+ # ๋ชจ๋ธ ๊ฐ€์ค‘์น˜ ๋‹ค์šด๋กœ๋“œ
40
  try:
41
  model_path = hf_hub_download(repo_id=HF_MODEL_REPO_ID, filename=HF_MODEL_FILENAME)
42
  print(f"โœ… ๋ชจ๋ธ ํŒŒ์ผ ๋‹ค์šด๋กœ๋“œ ์„ฑ๊ณต: {model_path}")
 
44
  mem_after_dl = process.memory_info().rss / (1024 * 1024)
45
  print(f"๐Ÿ“ฆ ๋ชจ๋ธ ๋‹ค์šด๋กœ๋“œ ํ›„ ๋ฉ”๋ชจ๋ฆฌ ์‚ฌ์šฉ๋Ÿ‰: {mem_after_dl:.2f} MB")
46
 
47
+ # state_dict ๋กœ๋“œ
48
+ state_dict = torch.load(model_path, map_location=device)
49
+ model.load_state_dict(state_dict)
50
  model.eval()
51
 
52
  mem_after_load = process.memory_info().rss / (1024 * 1024)
53
  print(f"๐Ÿ“ฆ ๋ชจ๋ธ ๋กœ๋“œ ํ›„ ๋ฉ”๋ชจ๋ฆฌ ์‚ฌ์šฉ๋Ÿ‰: {mem_after_load:.2f} MB")
54
+ print("โœ… ๋ชจ๋ธ ๋กœ๋“œ ๋ฐ ์ค€๋น„ ์™„๋ฃŒ.")
55
  except Exception as e:
56
+ print(f"โŒ Error: ๋ชจ๋ธ ๋กœ๋“œ ์ค‘ ์˜ค๋ฅ˜ ๋ฐœ์ƒ: {e}")
57
  sys.exit(1)
58
 
59
  # ์˜ˆ์ธก API