DIVYA-NSHU99 committed on
Commit
b2c073b
·
verified ·
1 Parent(s): dd15286

Update app/main.py

Browse files
Files changed (1) hide show
  1. app/main.py +32 -8
app/main.py CHANGED
@@ -2,12 +2,18 @@ from fastapi import FastAPI, HTTPException
2
  from pydantic import BaseModel
3
  from typing import Optional
4
  import os
 
5
  from tenacity import retry, stop_after_attempt, wait_exponential, retry_if_exception_type
6
 
7
  # Ensure models are cached to the runtime disk
8
  os.environ["HF_HOME"] = "/tmp/.cache/huggingface"
9
 
10
- # Import your analyzer
 
 
 
 
 
11
  from app.src.main import TrademarkAnalyzer
12
  from app.src.linguistic import LinguisticAnalyzer
13
 
@@ -25,27 +31,45 @@ analyzer = TrademarkAnalyzer(descriptive_keywords_path=data_path)
25
  @retry(
26
  stop=stop_after_attempt(3),
27
  wait=wait_exponential(multiplier=1, min=2, max=10),
28
- retry=retry_if_exception_type(Exception) # Catch any exception during warmup
29
  )
30
  def warmup():
31
  """
32
- Pre-download all required models with automatic retries.
33
- This handles transient network issues and corrupted cache.
34
  """
35
- print("Warming up: Attempting to load models...")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
36
 
37
- # 1. Load spaCy model
38
  print("Loading spaCy model...")
39
  LinguisticAnalyzer._get_nlp()
40
  print("✅ spaCy model loaded.")
41
 
42
- # 2. Preload sentence‑transformer model (embedding)
43
  if hasattr(analyzer, 'embedding') and hasattr(analyzer.embedding, 'model'):
44
  print("Loading embedding model...")
45
  _ = analyzer.embedding.model
46
  print("✅ Embedding model ready.")
47
 
48
- # 3. Preload cross‑encoder model
49
  if hasattr(analyzer, 'cross_encoder') and hasattr(analyzer.cross_encoder, 'model'):
50
  print("Loading cross-encoder model...")
51
  _ = analyzer.cross_encoder.model
 
2
  from pydantic import BaseModel
3
  from typing import Optional
4
  import os
5
+ import nltk
6
  from tenacity import retry, stop_after_attempt, wait_exponential, retry_if_exception_type
7
 
8
  # Ensure models are cached to the runtime disk
9
  os.environ["HF_HOME"] = "/tmp/.cache/huggingface"
10
 
11
+ # Set NLTK data path (must match Dockerfile ENV)
12
+ nltk_data_path = "/tmp/.cache/nltk"
13
+ os.environ["NLTK_DATA"] = nltk_data_path
14
+ nltk.data.path.append(nltk_data_path)
15
+
16
+ # Import your analyzer (after setting paths)
17
  from app.src.main import TrademarkAnalyzer
18
  from app.src.linguistic import LinguisticAnalyzer
19
 
 
31
  @retry(
32
  stop=stop_after_attempt(3),
33
  wait=wait_exponential(multiplier=1, min=2, max=10),
34
+ retry=retry_if_exception_type(Exception)
35
  )
36
  def warmup():
37
  """
38
+ Pre-download all required models and NLTK data with automatic retries.
 
39
  """
40
+ print("Warming up: Attempting to load models and NLTK data...")
41
+
42
+ # ---- NLTK data ----
43
+ # Download WordNet if missing
44
+ try:
45
+ nltk.data.find('corpora/wordnet')
46
+ print("✅ WordNet already present.")
47
+ except LookupError:
48
+ print("Downloading WordNet...")
49
+ nltk.download('wordnet', download_dir=nltk_data_path)
50
+ print("✅ WordNet downloaded.")
51
+
52
+ # Download Punkt tokenizer (used by sent_tokenize)
53
+ try:
54
+ nltk.data.find('tokenizers/punkt')
55
+ print("✅ Punkt tokenizer already present.")
56
+ except LookupError:
57
+ print("Downloading Punkt tokenizer...")
58
+ nltk.download('punkt', download_dir=nltk_data_path)
59
+ print("✅ Punkt tokenizer downloaded.")
60
 
61
+ # ---- spaCy model ----
62
  print("Loading spaCy model...")
63
  LinguisticAnalyzer._get_nlp()
64
  print("✅ spaCy model loaded.")
65
 
66
+ # ---- Sentence‑transformer embedding model ----
67
  if hasattr(analyzer, 'embedding') and hasattr(analyzer.embedding, 'model'):
68
  print("Loading embedding model...")
69
  _ = analyzer.embedding.model
70
  print("✅ Embedding model ready.")
71
 
72
+ # ---- Cross‑encoder model ----
73
  if hasattr(analyzer, 'cross_encoder') and hasattr(analyzer.cross_encoder, 'model'):
74
  print("Loading cross-encoder model...")
75
  _ = analyzer.cross_encoder.model