nairut commited on
Commit
5f557e8
·
verified ·
1 Parent(s): c9679e9

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +26 -40
app.py CHANGED
@@ -1,50 +1,42 @@
1
  import os
2
  import torch
3
- from torch import nn
4
  from fastapi import FastAPI
5
  from pydantic import BaseModel, Field
6
  from comet import load_from_checkpoint
7
  from huggingface_hub import snapshot_download, HfApi
8
 
9
  # ==========================================================
10
- # ⚙️ Configuração de memória do PyTorch
11
- # ==========================================================
12
- os.environ["PYTORCH_ALLOC_CONF"] = "expandable_segments:True"
13
-
14
- # ==========================================================
15
- # 🚀 Configuração da API
16
  # ==========================================================
17
  app = FastAPI(
18
  title="XCOMET-XXL API",
19
- version="1.5.0",
20
- description="API para avaliação de traduções usando Unbabel/XCOMET-XXL, "
21
- "compatível com campos 'source', 'target' e 'human_translation_ref'."
22
  )
23
 
24
  MODEL_NAME = "Unbabel/XCOMET-XXL"
25
  HF_TOKEN = os.environ.get("HF_TOKEN") # defina nas Secrets do Space
26
- SPACE_REPO_ID = os.environ.get("SPACE_REPO_ID", "nairut/comet-xxl")
27
 
28
- # ==========================================================
29
- # 📂 Caminho persistente (150 GB Medium Storage)
30
- # ==========================================================
31
- MODEL_DIR = "/data/model"
32
  MODEL_CKPT = os.path.join(MODEL_DIR, "checkpoints", "model.ckpt")
33
 
34
 
35
  # ==========================================================
36
- # ⚙️ Função auxiliar: baixa e persiste o modelo
37
  # ==========================================================
38
  def ensure_model_persisted_once():
39
  """
40
- Faz o download do modelo XCOMET-XXL para /data/model (caso ainda não exista)
41
  e tenta commitar essa pasta no próprio Space, para persistência.
42
  """
43
  if os.path.exists(MODEL_CKPT):
44
  print(f"✅ Modelo já existe em {MODEL_CKPT}. Pulando download.")
45
  return
46
 
47
- print("🔽 Baixando snapshot do modelo para /data/model ...")
48
  snapshot_download(
49
  repo_id=MODEL_NAME,
50
  token=HF_TOKEN,
@@ -62,7 +54,7 @@ def ensure_model_persisted_once():
62
  repo_type="space",
63
  folder_path=MODEL_DIR,
64
  path_in_repo="model",
65
- commit_message="Persistência automática do modelo XCOMET-XXL"
66
  )
67
  print("✅ Modelo persistido no Space.")
68
  except Exception as e:
@@ -71,32 +63,26 @@ def ensure_model_persisted_once():
71
 
72
 
73
  # ==========================================================
74
- # 📦 Inicialização do modelo (multi-GPU + persistência)
 
 
 
 
 
 
 
 
 
 
75
  # ==========================================================
76
  ensure_model_persisted_once()
77
 
78
  print(f"📂 Carregando modelo de {MODEL_CKPT} ...")
79
-
80
- # Carrega o modelo (sem map_location, compatível com COMET)
81
  model = load_from_checkpoint(MODEL_CKPT)
82
- model.eval()
83
-
84
- # Detecta GPUs disponíveis
85
- num_gpus = torch.cuda.device_count()
86
- print(f"🎮 GPUs detectadas: {num_gpus}")
87
-
88
- # Se houver mais de uma GPU, ativa DataParallel
89
- if num_gpus > 1:
90
- print("⚙️ Ativando DataParallel para usar múltiplas GPUs...")
91
- model = nn.DataParallel(model)
92
-
93
- # Move o modelo para GPU (caso disponível)
94
- if torch.cuda.is_available():
95
- model.to("cuda")
96
- print("✅ Modelo XCOMET-XXL carregado e distribuído nas GPUs.")
97
- else:
98
- print("⚠️ Nenhuma GPU detectada. Rodando em CPU (lento).")
99
 
 
 
100
 
101
 
102
  # ==========================================================
@@ -191,4 +177,4 @@ def score_batch(pairs: list[TranslationPair]):
191
  # ==========================================================
192
  if __name__ == "__main__":
193
  import uvicorn
194
- uvicorn.run(app, host="0.0.0.0", port=7860)
 
1
  import os
2
  import torch
 
3
  from fastapi import FastAPI
4
  from pydantic import BaseModel, Field
5
  from comet import load_from_checkpoint
6
  from huggingface_hub import snapshot_download, HfApi
7
 
8
  # ==========================================================
9
+ # 🚀 Configuração da API
 
 
 
 
 
10
  # ==========================================================
11
  app = FastAPI(
12
  title="XCOMET-XXL API",
13
+ version="2.0.0",
14
+ description="API para avaliação de traduções usando Unbabel/XCOMET-XXL API "
15
+
16
  )
17
 
18
  MODEL_NAME = "Unbabel/XCOMET-XXL"
19
  HF_TOKEN = os.environ.get("HF_TOKEN") # defina nas Secrets do Space
20
+ SPACE_REPO_ID = os.environ.get("SPACE_REPO_ID", "nairut/xcomet-xxl")
21
 
22
+ # Diretório de cache local (dentro do Space ou ambiente local)
23
+ MODEL_DIR = os.path.join(os.path.dirname(__file__), "model")
 
 
24
  MODEL_CKPT = os.path.join(MODEL_DIR, "checkpoints", "model.ckpt")
25
 
26
 
27
  # ==========================================================
28
+ # ⚙️ Função auxiliar: baixa e persiste o modelo
29
  # ==========================================================
30
  def ensure_model_persisted_once():
31
  """
32
+ Faz o download do modelo COMETKiwi-DA-XXL para ./model (caso ainda não exista)
33
  e tenta commitar essa pasta no próprio Space, para persistência.
34
  """
35
  if os.path.exists(MODEL_CKPT):
36
  print(f"✅ Modelo já existe em {MODEL_CKPT}. Pulando download.")
37
  return
38
 
39
+ print("🔽 Baixando snapshot do modelo para ./model ...")
40
  snapshot_download(
41
  repo_id=MODEL_NAME,
42
  token=HF_TOKEN,
 
54
  repo_type="space",
55
  folder_path=MODEL_DIR,
56
  path_in_repo="model",
57
+ commit_message="Persistência automática do modelo COMETKiwi-DA-XXL"
58
  )
59
  print("✅ Modelo persistido no Space.")
60
  except Exception as e:
 
63
 
64
 
65
  # ==========================================================
66
+ # ♻️ Inicialização limpa
67
+ # ==========================================================
68
+ # Remove da memória qualquer modelo carregado anteriormente
69
+ if "model" in globals():
70
+ del model
71
+ torch.cuda.empty_cache()
72
+ print("🧹 Modelo anterior removido da memória.")
73
+
74
+
75
+ # ==========================================================
76
+ # 📦 Inicialização do modelo
77
  # ==========================================================
78
  ensure_model_persisted_once()
79
 
80
  print(f"📂 Carregando modelo de {MODEL_CKPT} ...")
 
 
81
  model = load_from_checkpoint(MODEL_CKPT)
82
+ print("✅ Modelo COMETKiwi-DA-XXL carregado com sucesso!")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
83
 
84
+ USE_GPU = 1 if torch.cuda.is_available() else 0
85
+ print(f"⚙️ GPU detectada: {'sim' if USE_GPU else 'não'}")
86
 
87
 
88
  # ==========================================================
 
177
  # ==========================================================
178
  if __name__ == "__main__":
179
  import uvicorn
180
+ uvicorn.run(app, host="0.0.0.0", port=7860)