hiddenFront commited on
Commit
4607c9c
·
verified ·
1 Parent(s): 9dd37b1

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +3 -7
app.py CHANGED
@@ -37,11 +37,6 @@ class CustomClassifier(torch.nn.Module):
37
  pooled_output = outputs[1] # CLS 토큰
38
  return self.classifier(pooled_output)
39
 
40
-
41
- model = CustomClassifier()
42
- model.load_state_dict(torch.load(model_path, map_location=device))
43
- model.eval()
44
-
45
  HF_MODEL_REPO_ID = "hiddenFront/TextClassifier"
46
  HF_MODEL_FILENAME = "textClassifierModel.pt"
47
 
@@ -50,7 +45,6 @@ process = psutil.Process(os.getpid())
50
  mem_before = process.memory_info().rss / (1024 * 1024)
51
  print(f"📦 모델 다운로드 전 메모리 사용량: {mem_before:.2f} MB")
52
 
53
- # 모델 가중치 다운로드
54
  try:
55
  model_path = hf_hub_download(repo_id=HF_MODEL_REPO_ID, filename=HF_MODEL_FILENAME)
56
  print(f"✅ 모델 파일 다운로드 성공: {model_path}")
@@ -58,7 +52,8 @@ try:
58
  mem_after_dl = process.memory_info().rss / (1024 * 1024)
59
  print(f"📦 모델 다운로드 후 메모리 사용량: {mem_after_dl:.2f} MB")
60
 
61
- # state_dict 로드
 
62
  state_dict = torch.load(model_path, map_location=device)
63
  model.load_state_dict(state_dict)
64
  model.eval()
@@ -70,6 +65,7 @@ except Exception as e:
70
  print(f"❌ Error: 모델 로드 중 오류 발생: {e}")
71
  sys.exit(1)
72
 
 
73
  # 예측 API
74
  @app.post("/predict")
75
  async def predict_api(request: Request):
 
37
  pooled_output = outputs[1] # CLS 토큰
38
  return self.classifier(pooled_output)
39
 
 
 
 
 
 
40
  HF_MODEL_REPO_ID = "hiddenFront/TextClassifier"
41
  HF_MODEL_FILENAME = "textClassifierModel.pt"
42
 
 
45
  mem_before = process.memory_info().rss / (1024 * 1024)
46
  print(f"📦 모델 다운로드 전 메모리 사용량: {mem_before:.2f} MB")
47
 
 
48
  try:
49
  model_path = hf_hub_download(repo_id=HF_MODEL_REPO_ID, filename=HF_MODEL_FILENAME)
50
  print(f"✅ 모델 파일 다운로드 성공: {model_path}")
 
52
  mem_after_dl = process.memory_info().rss / (1024 * 1024)
53
  print(f"📦 모델 다운로드 후 메모리 사용량: {mem_after_dl:.2f} MB")
54
 
55
+ # 모델 구성 및 state_dict 로드
56
+ model = CustomClassifier()
57
  state_dict = torch.load(model_path, map_location=device)
58
  model.load_state_dict(state_dict)
59
  model.eval()
 
65
  print(f"❌ Error: 모델 로드 중 오류 발생: {e}")
66
  sys.exit(1)
67
 
68
+
69
  # 예측 API
70
  @app.post("/predict")
71
  async def predict_api(request: Request):