# app/main.py — FastAPI entry point for the Trademark Descriptiveness API.
import logging
import os
from typing import Optional

import nltk
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from tenacity import retry, stop_after_attempt, wait_exponential, retry_if_exception_type
# --- Runtime cache configuration -------------------------------------------
# These environment variables MUST be set before `app.src.*` is imported below,
# so that any Hugging Face / NLTK downloads triggered by those modules land on
# the writable /tmp disk (Spaces containers have a read-only home directory).
# Ensure models are cached to the runtime disk
os.environ["HF_HOME"] = "/tmp/.cache/huggingface"
# Set NLTK data path (must match Dockerfile ENV)
nltk_data_path = "/tmp/.cache/nltk"
os.environ["NLTK_DATA"] = nltk_data_path
nltk.data.path.append(nltk_data_path)
# Import your analyzer (after setting paths)
from app.src.main import TrademarkAnalyzer
from app.src.linguistic import LinguisticAnalyzer
app = FastAPI(title="Trademark Descriptiveness API")
# Check that the data file exists
# Relative path: assumes the process working directory is the project root
# (as set by the Dockerfile) — TODO confirm against the container setup.
data_path = "app/data/descriptive_keywords.json"
if not os.path.exists(data_path):
    # Missing keyword data is non-fatal; the analyzer degrades gracefully.
    print(f"Warning: Data file not found at {data_path}. Keyword overlap will be disabled.")
# Initialize analyzer (models not loaded yet)
analyzer = TrademarkAnalyzer(descriptive_keywords_path=data_path)
# Decorator order matters: `retry` wraps the function first, then the retried
# wrapper is registered as the startup hook — so the whole warmup (downloads
# included) is retried up to 3 times with exponential backoff on any exception.
# NOTE(review): `@app.on_event` is deprecated in recent FastAPI in favour of
# lifespan handlers, but it still works; left unchanged here.
@app.on_event("startup")
@retry(
    stop=stop_after_attempt(3),
    wait=wait_exponential(multiplier=1, min=2, max=10),
    retry=retry_if_exception_type(Exception)
)
def warmup():
    """
    Pre-download all required models and NLTK data with automatic retries.
    """
    print("Warming up: Attempting to load models and NLTK data...")
    # ---- NLTK data ----
    # Download WordNet if missing (EAFP: `find` raises LookupError when absent).
    try:
        nltk.data.find('corpora/wordnet')
        print("✅ WordNet already present.")
    except LookupError:
        print("Downloading WordNet...")
        nltk.download('wordnet', download_dir=nltk_data_path)
        print("✅ WordNet downloaded.")
    # Download Punkt tokenizer (used by sent_tokenize)
    # NOTE(review): newer NLTK releases may also require 'punkt_tab' — verify
    # against the pinned nltk version if tokenization fails at runtime.
    try:
        nltk.data.find('tokenizers/punkt')
        print("✅ Punkt tokenizer already present.")
    except LookupError:
        print("Downloading Punkt tokenizer...")
        nltk.download('punkt', download_dir=nltk_data_path)
        print("✅ Punkt tokenizer downloaded.")
    # ---- spaCy model ----
    # Touch the class-level loader so the spaCy pipeline is cached before
    # the first request arrives.
    print("Loading spaCy model...")
    LinguisticAnalyzer._get_nlp()
    print("✅ spaCy model loaded.")
    # ---- Sentence‑transformer embedding model ----
    # `hasattr` guards keep warmup tolerant of analyzer configurations that
    # omit these sub-components; accessing `.model` forces the lazy load.
    if hasattr(analyzer, 'embedding') and hasattr(analyzer.embedding, 'model'):
        print("Loading embedding model...")
        _ = analyzer.embedding.model
        print("✅ Embedding model ready.")
    # ---- Cross‑encoder model ----
    if hasattr(analyzer, 'cross_encoder') and hasattr(analyzer.cross_encoder, 'model'):
        print("Loading cross-encoder model...")
        _ = analyzer.cross_encoder.model
        print("✅ Cross-encoder ready.")
    print("✅ Warmup complete.")
class AnalyzeRequest(BaseModel):
    """Request body for POST /analyze."""
    mark: str  # the trademark text to evaluate
    goods: str  # description of the goods/services the mark covers
    goods_class: Optional[str] = None  # optional goods classification identifier — presumably a Nice class; verify against TrademarkAnalyzer
class AnalyzeResponse(BaseModel):
    """Response body for POST /analyze; fields mirror the keys of the dict returned by `TrademarkAnalyzer.analyze`."""
    descriptive_score: float  # how descriptive the mark is of the goods
    generic_score: float      # how generic the mark is
    reasons: list[str]        # human-readable factors behind the scores
    explanation: str          # summary explanation of the verdict
    details: dict             # analyzer-specific diagnostic payload (free-form)
@app.get("/")
def read_root():
    """Root endpoint: returns a simple liveness message."""
    payload = dict(message="Trademark API is running")
    return payload
@app.get("/health")
def health_check():
    """Health-check endpoint for deployment/readiness probes."""
    status_payload = dict(status="ok")
    return status_payload
@app.post("/analyze", response_model=AnalyzeResponse)
def analyze(request: AnalyzeRequest):
    """
    Run the trademark descriptiveness analysis for a single mark.

    Args:
        request: the mark, its goods description, and optional goods class.

    Returns:
        AnalyzeResponse built from the dict returned by `analyzer.analyze`.

    Raises:
        HTTPException: 500 with the underlying error message when the
            analyzer fails for any reason.
    """
    try:
        result = analyzer.analyze(
            mark=request.mark,
            goods=request.goods,
            goods_class=request.goods_class
        )
        return AnalyzeResponse(**result)
    except HTTPException:
        # Don't re-wrap framework exceptions (e.g. from validation helpers)
        # as opaque 500s — let them propagate with their original status.
        raise
    except Exception as e:
        # logger.exception records the full traceback (the previous print()
        # lost it); `from e` preserves the exception chain for debugging.
        logging.getLogger(__name__).exception("Error during analysis")
        raise HTTPException(status_code=500, detail=str(e)) from e