Spaces:
Sleeping
Sleeping
| from fastapi import FastAPI, HTTPException | |
| from pydantic import BaseModel | |
| from typing import Optional | |
| import os | |
| import nltk | |
| from tenacity import retry, stop_after_attempt, wait_exponential, retry_if_exception_type | |
# Point the Hugging Face model cache at the runtime disk. This has to happen
# before the analyzer modules below are imported.
os.environ["HF_HOME"] = "/tmp/.cache/huggingface"

# NLTK data directory (must match the Dockerfile ENV).
nltk_data_path = "/tmp/.cache/nltk"
os.environ["NLTK_DATA"] = nltk_data_path
# NOTE(review): nltk is already imported above, so the environment variable
# alone presumably does not extend its search path; register the directory
# explicitly as well.
nltk.data.path.append(nltk_data_path)

# Analyzer imports are deliberately placed after the cache paths are set.
from app.src.main import TrademarkAnalyzer
from app.src.linguistic import LinguisticAnalyzer

app = FastAPI(title="Trademark Descriptiveness API")

# The keyword file is optional: when it is missing, only warn.
data_path = "app/data/descriptive_keywords.json"
if not os.path.exists(data_path):
    print(f"Warning: Data file not found at {data_path}. Keyword overlap will be disabled.")

# Build the shared analyzer instance (models are not loaded at this point).
analyzer = TrademarkAnalyzer(descriptive_keywords_path=data_path)
def _ensure_nltk_resource(resource: str, package: str, label: str) -> None:
    """Ensure one NLTK resource is available, downloading it when missing.

    Args:
        resource: Lookup path given to ``nltk.data.find``
            (e.g. ``'corpora/wordnet'``).
        package: Package identifier given to ``nltk.download``
            (e.g. ``'wordnet'``).
        label: Human-readable name used in the progress messages.

    Raises:
        RuntimeError: If ``nltk.download`` reports that the download failed.
    """
    try:
        nltk.data.find(resource)
    except LookupError:
        print(f"Downloading {label}...")
        # nltk.download() reports failure by returning False instead of
        # raising; without this check a failed download would be announced
        # as a success and never retried.
        if not nltk.download(package, download_dir=nltk_data_path):
            raise RuntimeError(
                f"Failed to download NLTK package '{package}'."
            ) from None
        print(f"✅ {label} downloaded.")
    else:
        print(f"✅ {label} already present.")


@retry(
    stop=stop_after_attempt(3),
    wait=wait_exponential(multiplier=1, min=2, max=30),
    retry=retry_if_exception_type((OSError, RuntimeError)),
    reraise=True,  # surface the real error, not tenacity's RetryError
)
def warmup():
    """
    Pre-download all required models and NLTK data with automatic retries.

    The whole routine is re-run (up to three attempts, with exponential
    backoff) when a step raises ``OSError`` or ``RuntimeError``. Steps that
    already succeeded are cheap on a re-run because the NLTK resources are
    looked up before any download is attempted.

    Raises:
        RuntimeError: If an NLTK download still fails on the last attempt.
        OSError: If a model load still fails with an OS/network error on the
            last attempt.
    """
    print("Warming up: Attempting to load models and NLTK data...")

    # ---- NLTK data ----
    _ensure_nltk_resource('corpora/wordnet', 'wordnet', "WordNet")
    # Punkt tokenizer (used by sent_tokenize)
    _ensure_nltk_resource('tokenizers/punkt', 'punkt', "Punkt tokenizer")

    # ---- spaCy model ----
    print("Loading spaCy model...")
    LinguisticAnalyzer._get_nlp()
    print("✅ spaCy model loaded.")

    # ---- Sentence-transformer embedding model ----
    # NOTE(review): if `model` is a lazily-loading property, the hasattr()
    # check itself already triggers the load - confirm against the analyzer.
    if hasattr(analyzer, 'embedding') and hasattr(analyzer.embedding, 'model'):
        print("Loading embedding model...")
        _ = analyzer.embedding.model
        print("✅ Embedding model ready.")

    # ---- Cross-encoder model ----
    if hasattr(analyzer, 'cross_encoder') and hasattr(analyzer.cross_encoder, 'model'):
        print("Loading cross-encoder model...")
        _ = analyzer.cross_encoder.model
        print("✅ Cross-encoder ready.")

    print("✅ Warmup complete.")
class AnalyzeRequest(BaseModel):
    """Request body for a trademark analysis.

    The three fields are forwarded unchanged to ``analyzer.analyze``.
    """

    # The mark to analyze.
    mark: str
    # The goods the mark is analyzed against.
    goods: str
    # Optional goods class; ``None`` when the caller does not supply one.
    goods_class: Optional[str] = None
class AnalyzeResponse(BaseModel):
    """Response body for a trademark analysis.

    Instances are built by unpacking the mapping returned from
    ``analyzer.analyze``, so that mapping must provide every field below.
    """

    # Scores produced by the analyzer (value range not visible here).
    descriptive_score: float
    generic_score: float
    # Individual reasons behind the scores, plus a single explanation text.
    reasons: list[str]
    explanation: str
    # Free-form extra data; its schema is defined by the analyzer.
    details: dict
# Register the handler on the FastAPI app; without the decorator the function
# is never exposed as a route.
@app.get("/")
def read_root():
    """Return a static message confirming that the API is running."""
    return {"message": "Trademark API is running"}
# Register the handler on the FastAPI app; without the decorator the function
# is never exposed as a route.
# NOTE(review): the "/health" path is inferred from the function name -
# confirm it matches what the deployment's health probe expects.
@app.get("/health")
def health_check():
    """Return a constant ``{"status": "ok"}`` payload for liveness probes."""
    return {"status": "ok"}
# Register the handler on the FastAPI app; without the decorator the function
# is never exposed as a route.
# NOTE(review): the "/analyze" path is inferred from the function name -
# confirm it matches existing clients.
@app.post("/analyze", response_model=AnalyzeResponse)
def analyze(request: AnalyzeRequest):
    """Run the trademark analysis for one mark/goods pair.

    Args:
        request: Validated request body; its ``mark``, ``goods`` and
            ``goods_class`` fields are passed to ``analyzer.analyze``.

    Returns:
        An ``AnalyzeResponse`` built from the mapping the analyzer returns.

    Raises:
        HTTPException: With status 500 and the error text as detail when the
            analysis or the response construction fails.
    """
    try:
        result = analyzer.analyze(
            mark=request.mark,
            goods=request.goods,
            goods_class=request.goods_class,
        )
        return AnalyzeResponse(**result)
    except HTTPException:
        # Let deliberate HTTP errors through instead of rewriting them as 500.
        raise
    except Exception as e:
        print(f"Error during analysis: {e}")
        # Chain the cause so the original traceback stays visible in logs.
        raise HTTPException(status_code=500, detail=str(e)) from e