hiddenFront committed on
Commit
e66afc2
·
verified ·
1 Parent(s): 4e6dd7c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +21 -42
app.py CHANGED
@@ -1,13 +1,11 @@
1
  from fastapi import FastAPI, Request
2
- from transformers import AutoTokenizer, BertForSequenceClassification, BertConfig
3
  from huggingface_hub import hf_hub_download
4
  import torch
5
- import numpy as np
6
  import pickle
 
 
7
  import sys
8
- import collections
9
- import os # os 모듈 임포트
10
- import psutil # ๋ฉ”๋ชจ๋ฆฌ ์‚ฌ์šฉ๋Ÿ‰ ํ™•์ธ์„ ์œ„ํ•ด psutil ์ž„ํฌํŠธ (requirements.txt์— ์ถ”๊ฐ€ ํ•„์š”)
11
 
12
  app = FastAPI()
13
  device = torch.device("cpu")
@@ -16,61 +14,42 @@ device = torch.device("cpu")
16
  try:
17
  with open("category.pkl", "rb") as f:
18
  category = pickle.load(f)
19
- print("category.pkl 로드 성공.")
20
  except FileNotFoundError:
21
- print("Error: category.pkl 파일을 찾을 수 없습니다. 프로젝트 루트에 있는지 확인하세요.")
22
  sys.exit(1)
23
 
24
  # 토크나이저 로드
25
  tokenizer = AutoTokenizer.from_pretrained("skt/kobert-base-v1")
26
- print("토크나이저 로드 성공.")
27
 
28
  HF_MODEL_REPO_ID = "hiddenFront/TextClassifier"
29
  HF_MODEL_FILENAME = "textClassifierModel.pt"
30
 
31
- # --- ๋ฉ”๋ชจ๋ฆฌ ์‚ฌ์šฉ๋Ÿ‰ ๋กœ๊น… ์‹œ์ž‘ ---
32
  process = psutil.Process(os.getpid())
33
- mem_before_model_download = process.memory_info().rss / (1024 * 1024) # MB ๋‹จ์œ„
34
- print(f"๋ชจ๋ธ ๋‹ค์šด๋กœ๋“œ ์ „ ๋ฉ”๋ชจ๋ฆฌ ์‚ฌ์šฉ๋Ÿ‰: {mem_before_model_download:.2f} MB")
35
- # --- ๋ฉ”๋ชจ๋ฆฌ ์‚ฌ์šฉ๋Ÿ‰ ๋กœ๊น… ๋ ---
36
 
 
37
  try:
38
  model_path = hf_hub_download(repo_id=HF_MODEL_REPO_ID, filename=HF_MODEL_FILENAME)
39
- print(f"๋ชจ๋ธ ํŒŒ์ผ์ด '{model_path}'์— ์„ฑ๊ณต์ ์œผ๋กœ ๋‹ค์šด๋กœ๋“œ๋˜์—ˆ์Šต๋‹ˆ๋‹ค.")
40
 
41
- # --- ๋ฉ”๋ชจ๋ฆฌ ์‚ฌ์šฉ๋Ÿ‰ ๋กœ๊น… ์‹œ์ž‘ ---
42
- mem_after_model_download = process.memory_info().rss / (1024 * 1024) # MB ๋‹จ์œ„
43
- print(f"๋ชจ๋ธ ๋‹ค์šด๋กœ๋“œ ํ›„ ๋ฉ”๋ชจ๋ฆฌ ์‚ฌ์šฉ๋Ÿ‰: {mem_after_model_download:.2f} MB")
44
- # --- ๋ฉ”๋ชจ๋ฆฌ ์‚ฌ์šฉ๋Ÿ‰ ๋กœ๊น… ๋ ---
45
-
46
- # 1. ๋ชจ๋ธ ์•„ํ‚คํ…์ฒ˜ ์ •์˜ (๊ฐ€์ค‘์น˜๋Š” ๋กœ๋“œํ•˜์ง€ ์•Š๊ณ  ๊ตฌ์กฐ๋งŒ ์ดˆ๊ธฐํ™”)
47
- config = BertConfig.from_pretrained("skt/kobert-base-v1", num_labels=len(category))
48
- model = BertForSequenceClassification(config)
49
-
50
- # 2. ๋‹ค์šด๋กœ๋“œ๋œ ํŒŒ์ผ์—์„œ state_dict๋ฅผ ๋กœ๋“œ
51
- loaded_state_dict = torch.load(model_path, map_location=device)
52
-
53
- # 3. ๋กœ๋“œ๋œ state_dict๋ฅผ ์ •์˜๋œ ๋ชจ๋ธ์— ์ ์šฉ
54
- new_state_dict = collections.OrderedDict()
55
- for k, v in loaded_state_dict.items():
56
- name = k
57
- if name.startswith('module.'):
58
- name = name[7:]
59
- new_state_dict[name] = v
60
-
61
- model.load_state_dict(new_state_dict)
62
-
63
- # --- ๋ฉ”๋ชจ๋ฆฌ ์‚ฌ์šฉ๋Ÿ‰ ๋กœ๊น… ์‹œ์ž‘ ---
64
- mem_after_model_load = process.memory_info().rss / (1024 * 1024) # MB ๋‹จ์œ„
65
- print(f"๋ชจ๋ธ ๋กœ๋“œ ๋ฐ state_dict ์ ์šฉ ํ›„ ๋ฉ”๋ชจ๋ฆฌ ์‚ฌ์šฉ๋Ÿ‰: {mem_after_model_load:.2f} MB")
66
- # --- ๋ฉ”๋ชจ๋ฆฌ ์‚ฌ์šฉ๋Ÿ‰ ๋กœ๊น… ๋ ---
67
 
 
68
  model.eval()
69
- print("모델 로드 성공.")
 
 
 
70
  except Exception as e:
71
- print(f"Error: 모델 다운로드 또는 로드 중 오류 발생: {e}")
72
  sys.exit(1)
73
 
 
74
  @app.post("/predict")
75
  async def predict_api(request: Request):
76
  data = await request.json()
@@ -86,6 +65,6 @@ async def predict_api(request: Request):
86
  outputs = model(**encoded)
87
  probs = torch.nn.functional.softmax(outputs.logits, dim=1)
88
  predicted = torch.argmax(probs, dim=1).item()
89
-
90
  label = list(category.keys())[predicted]
91
  return {"text": text, "classification": label}
 
1
  from fastapi import FastAPI, Request
2
+ from transformers import AutoTokenizer
3
  from huggingface_hub import hf_hub_download
4
  import torch
 
5
  import pickle
6
+ import os
7
+ import psutil
8
  import sys
 
 
 
9
 
10
  app = FastAPI()
11
  device = torch.device("cpu")
 
14
  try:
15
  with open("category.pkl", "rb") as f:
16
  category = pickle.load(f)
17
+ print("✅ category.pkl 로드 성공.")
18
  except FileNotFoundError:
19
+ print("❌ Error: category.pkl 파일을 찾을 수 없습니다. 프로젝트 루트에 있는지 확인하세요.")
20
  sys.exit(1)
21
 
22
  # 토크나이저 로드
23
  tokenizer = AutoTokenizer.from_pretrained("skt/kobert-base-v1")
24
+ print("✅ 토크나이저 로드 성공.")
25
 
26
  HF_MODEL_REPO_ID = "hiddenFront/TextClassifier"
27
  HF_MODEL_FILENAME = "textClassifierModel.pt"
28
 
29
+ # 메모리 확인
30
  process = psutil.Process(os.getpid())
31
+ mem_before = process.memory_info().rss / (1024 * 1024)
32
+ print(f"📦 모델 다운로드 전 메모리 사용량: {mem_before:.2f} MB")
 
33
 
34
+ # 모델 다운로드 및 로드
35
  try:
36
  model_path = hf_hub_download(repo_id=HF_MODEL_REPO_ID, filename=HF_MODEL_FILENAME)
37
+ print(f"✅ 모델 파일 다운로드 성공: {model_path}")
38
 
39
+ mem_after_dl = process.memory_info().rss / (1024 * 1024)
40
+ print(f"📦 모델 다운로드 후 메모리 사용량: {mem_after_dl:.2f} MB")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
41
 
42
+ model = torch.load(model_path, map_location=device) # 전체 모델 객체 로드
43
  model.eval()
44
+
45
+ mem_after_load = process.memory_info().rss / (1024 * 1024)
46
+ print(f"📦 모델 로드 후 메모리 사용량: {mem_after_load:.2f} MB")
47
+ print("✅ 모델 로드 성공")
48
  except Exception as e:
49
+ print(f"❌ Error: 모델 다운로드 또는 로드 중 오류 발생: {e}")
50
  sys.exit(1)
51
 
52
+ # 예측 API
53
  @app.post("/predict")
54
  async def predict_api(request: Request):
55
  data = await request.json()
 
65
  outputs = model(**encoded)
66
  probs = torch.nn.functional.softmax(outputs.logits, dim=1)
67
  predicted = torch.argmax(probs, dim=1).item()
68
+
69
  label = list(category.keys())[predicted]
70
  return {"text": text, "classification": label}