File size: 4,421 Bytes
e4eb82b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b82bbd8
 
 
 
 
 
e4eb82b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
936e189
 
e4eb82b
 
936e189
 
 
 
 
 
e4eb82b
936e189
 
 
 
 
 
 
e4eb82b
936e189
 
 
 
 
e4eb82b
936e189
e4eb82b
936e189
 
 
 
e4eb82b
 
b82bbd8
e4eb82b
 
b82bbd8
 
 
e4eb82b
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
"""
Main FastAPI application: wires NLP model providers, services, rate limiting, CORS and API routes
"""
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
import uvicorn
import logging
import os
from dotenv import load_dotenv
from slowapi import _rate_limit_exceeded_handler
from slowapi.errors import RateLimitExceeded

# Load environment variables from .env file
load_dotenv()

# Import our modules
from lib.routes import router
from lib.rate_limiter import limiter, rate_limit_handler
from lib.providers.model_providers import (
    SentimentModelProvider,
    NERModelProvider,
    TranslationModelProvider,
    ParaphraseModelProvider,
    SummarizationModelProvider
)
from lib.services import ParaphraseService, SentimentService, NERService, TranslationService, SummarizationService

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


def _parse_allowed_origins(raw):
    """Split a comma-separated CORS origins string into a clean list.

    Whitespace around each entry is stripped and empty entries (for example
    from a trailing comma) are dropped, so ``"a, b,"`` becomes ``["a", "b"]``.
    A bare ``"*"`` becomes ``["*"]``, which CORSMiddleware treats as
    "allow all origins".
    """
    return [origin.strip() for origin in raw.split(",") if origin.strip()]


# Get configuration from environment variables
# For Hugging Face Spaces, allow all origins by default
default_origins = "*" if os.getenv("HF_SPACE_ID") else "http://localhost:8000"
# Always a list of exact origins: CORSMiddleware expects a sequence, and an
# unstripped entry like " http://b.com" would never match the Origin header.
ALLOWED_ORIGINS = _parse_allowed_origins(os.getenv("ALLOWED_ORIGINS", default_origins))
ENVIRONMENT = os.getenv("ENVIRONMENT", "production" if os.getenv("HF_SPACE_ID") else "development")

logger.info("Starting application in %s mode", ENVIRONMENT)
logger.info("Allowed CORS origins: %s", ALLOWED_ORIGINS)

# Initialize FastAPI app
app = FastAPI(
    title="NLP Analysis API",
    description="A REST API for sentiment analysis, NER, translation, paraphrasing, and summarization using Hugging Face transformers",
    version="2.0.0"
)

# Add rate limiter to app state
app.state.limiter = limiter

# Add custom rate limit exception handler
app.add_exception_handler(RateLimitExceeded, rate_limit_handler)

# Add CORS middleware to allow requests from Flutter app
# SECURITY: Only allow requests from specified origins.
# The CORS spec forbids combining a wildcard origin with credentials. When
# both are set, Starlette's CORSMiddleware echoes back whatever Origin the
# request carried, which would let any site make credentialed requests. So
# credentials are only enabled for an explicit origin list. Works whether
# ALLOWED_ORIGINS is the string "*" or a list containing "*".
_allow_all_origins = "*" in ALLOWED_ORIGINS
app.add_middleware(
    CORSMiddleware,
    allow_origins=ALLOWED_ORIGINS,  # Controlled by environment variable
    allow_credentials=not _allow_all_origins,
    allow_methods=["GET", "POST"],  # Only allow needed HTTP methods
    allow_headers=["Content-Type", "Authorization", "X-API-Key"],  # Only allow needed headers
)

# Model providers, one per NLP task. Their load_model() methods are invoked
# separately by load_models() at startup. Construction order is left to right.
(
    sentiment_model,
    ner_model,
    translation_model,
    paraphrase_model,
    summarization_model,
) = (
    SentimentModelProvider(),
    NERModelProvider(),
    TranslationModelProvider(),
    ParaphraseModelProvider(),
    SummarizationModelProvider(),
)

# Services, each wrapping the provider for its own task.
(
    sentiment_service,
    ner_service,
    translation_service,
    paraphrase_service,
    summarization_service,
) = (
    SentimentService(sentiment_model),
    NERService(ner_model),
    TranslationService(translation_model),
    ParaphraseService(paraphrase_model),
    SummarizationService(summarization_model),
)


def load_models():
    """Load model weights for the startup-time providers.

    Sentiment and NER are essential: if either ``load_model()`` raises, the
    error is logged together with its traceback and re-raised. Paraphrase and
    summarization are optional: a failure is logged as a warning and loading
    continues. Translation models are not loaded here.

    Raises:
        Exception: whatever ``load_model()`` raised for an essential model.
    """
    logger.info("Loading models...")

    # Load essential models (sentiment and NER).
    # logger.exception keeps the traceback. This function typically runs in a
    # background task that nobody awaits, so the re-raised exception's
    # traceback would otherwise be easy to lose.
    try:
        sentiment_model.load_model()
        logger.info("✓ Sentiment model loaded")
    except Exception as e:
        logger.exception("✗ Error loading sentiment model: %s", e)
        raise

    try:
        ner_model.load_model()
        logger.info("✓ NER model loaded")
    except Exception as e:
        logger.exception("✗ Error loading NER model: %s", e)
        raise

    # Load optional models (don't fail startup if these fail)
    try:
        paraphrase_model.load_model()
        logger.info("✓ Paraphrase model loaded")
    except Exception as e:
        logger.warning("⚠ Paraphrase model failed to load (will load on-demand): %s", e)

    try:
        summarization_model.load_model()
        logger.info("✓ Summarization model loaded")
    except Exception as e:
        logger.warning("⚠ Summarization model failed to load (will load on-demand): %s", e)

    # Translation models are loaded on-demand based on language pairs
    logger.info("Core models loaded successfully!")


def _log_model_loading_result(task):
    """Report the outcome of the background model-loading task.

    Calling ``task.exception()`` also marks the exception as retrieved, so
    asyncio does not later emit "Task exception was never retrieved".
    """
    if task.cancelled():
        logger.warning("Background model loading was cancelled")
        return
    exc = task.exception()
    if exc is not None:
        logger.error("Background model loading failed", exc_info=exc)


# Load models on startup (non-blocking for HF Spaces health checks)
@app.on_event("startup")
async def startup_event():
    """Start loading models in a worker thread without blocking startup.

    The server can begin answering requests (e.g. health checks) while
    ``load_models`` runs in a thread. The task is kept on
    ``app.state.model_loading_task``, and its outcome is logged when it
    finishes.
    """
    import asyncio

    # Load models in background to allow health checks to respond quickly
    task = asyncio.create_task(asyncio.to_thread(load_models))
    # The event loop holds only weak references to tasks. Keep a strong one
    # so the task cannot be garbage-collected before it completes.
    app.state.model_loading_task = task
    task.add_done_callback(_log_model_loading_result)


# Include router
app.include_router(router)


if __name__ == "__main__":
    # Direct-run entry point. Auto-reload is only enabled in development,
    # because the reloader runs a file-watching supervisor process that
    # should not run in production. PORT overrides the default port 8000.
    uvicorn.run(
        "main:app",
        host="0.0.0.0",
        port=int(os.getenv("PORT", "8000")),
        reload=ENVIRONMENT == "development",
        log_level="info",
    )