m93 commited on
Commit
986654d
·
verified ·
1 Parent(s): 12e8a63

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +147 -0
app.py ADDED
@@ -0,0 +1,147 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # app.py - Deploy this to Hugging Face Spaces
2
+ # Install: pip install fastapi uvicorn torch transformers huggingface_hub
3
+
4
+ import json
5
+ import os
6
+ from pathlib import Path
7
+
8
+ import torch
9
+ import torch.nn as nn
10
+ from fastapi import FastAPI, HTTPException
11
+ from huggingface_hub import hf_hub_download
12
+ from pydantic import BaseModel
13
+ from transformers import AutoModel, AutoTokenizer
14
+
15
+ app = FastAPI(title="Sentiment Analysis API")
16
+
17
+ # Global variables for lazy loading
18
+ model = None
19
+ tokenizer = None
20
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
21
+
22
+
23
+ # Model definition (must match training code)
24
+ class SentimentClassifier(nn.Module):
25
+ def __init__(self):
26
+ super().__init__()
27
+ self.bert = AutoModel.from_pretrained("distilbert-base-uncased")
28
+ self.dropout = nn.Dropout(0.3)
29
+ self.classifier = nn.Linear(768, 2)
30
+
31
+ def forward(self, input_ids, attention_mask, **kwargs):
32
+ outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
33
+ pooled = outputs.last_hidden_state[:, 0]
34
+ x = self.dropout(pooled)
35
+ return self.classifier(x)
36
+
37
+
38
+ # Request/Response models
39
+ class PredictionRequest(BaseModel):
40
+ text: str
41
+
42
+
43
+ class PredictionResponse(BaseModel):
44
+ sentiment: str
45
+ confidence: float
46
+
47
+
48
+ def load_model_from_hf(repo_id: str):
49
+ """Load model from Hugging Face on-demand"""
50
+ global model, tokenizer
51
+
52
+ if model is not None:
53
+ return # Already loaded
54
+
55
+ print(f"📥 Loading model from {repo_id}...")
56
+
57
+ # Download model files
58
+ cache_dir = "./model_cache"
59
+ Path(cache_dir).mkdir(exist_ok=True)
60
+
61
+ model_path = hf_hub_download(
62
+ repo_id=repo_id, filename="model.pt", cache_dir=cache_dir
63
+ )
64
+
65
+ config_path = hf_hub_download(
66
+ repo_id=repo_id, filename="config.json", cache_dir=cache_dir
67
+ )
68
+
69
+ # Load tokenizer
70
+ tokenizer = AutoTokenizer.from_pretrained(repo_id, cache_dir=cache_dir)
71
+
72
+ # Load model
73
+ model = SentimentClassifier()
74
+ model.load_state_dict(torch.load(model_path, map_location=device))
75
+ model.to(device)
76
+ model.eval()
77
+
78
+ print(f"✅ Model loaded successfully on {device}")
79
+
80
+
81
+ @app.on_event("startup")
82
+ async def startup_event():
83
+ """Load model when server starts"""
84
+ # Read from environment variable or use default
85
+ REPO_ID = os.environ.get("MODEL_REPO_ID", "m93/sentiment-model")
86
+ load_model_from_hf(REPO_ID)
87
+
88
+
89
+ @app.get("/")
90
+ def root():
91
+ return {
92
+ "message": "Sentiment Analysis API",
93
+ "status": "running",
94
+ "endpoints": {
95
+ "/predict": "POST - Analyze sentiment of text",
96
+ "/health": "GET - Check if model is loaded",
97
+ "/docs": "GET - Interactive API documentation",
98
+ },
99
+ }
100
+
101
+
102
+ @app.get("/health")
103
+ def health_check():
104
+ return {
105
+ "status": "healthy",
106
+ "model_loaded": model is not None,
107
+ "device": str(device),
108
+ }
109
+
110
+
111
+ @app.post("/predict", response_model=PredictionResponse)
112
+ def predict(request: PredictionRequest):
113
+ if model is None or tokenizer is None:
114
+ raise HTTPException(status_code=503, detail="Model not loaded")
115
+
116
+ try:
117
+ # Tokenize input
118
+ inputs = tokenizer(
119
+ request.text,
120
+ return_tensors="pt",
121
+ padding=True,
122
+ truncation=True,
123
+ max_length=512,
124
+ )
125
+ inputs = {k: v.to(device) for k, v in inputs.items()}
126
+
127
+ # Get prediction
128
+ with torch.no_grad():
129
+ outputs = model(**inputs)
130
+ probs = torch.softmax(outputs, dim=1)
131
+ prediction = torch.argmax(probs, dim=1).item()
132
+ confidence = probs[0][prediction].item()
133
+
134
+ sentiment = "positive" if prediction == 1 else "negative"
135
+
136
+ return PredictionResponse(sentiment=sentiment, confidence=round(confidence, 4))
137
+
138
+ except Exception as e:
139
+ raise HTTPException(status_code=500, detail=str(e))
140
+
141
+
142
+ if __name__ == "__main__":
143
+ import uvicorn
144
+
145
+ port = int(os.environ.get("PORT", 7860)) # HF Spaces uses port 7860
146
+ print("🚀 Starting API server...")
147
+ uvicorn.run(app, host="0.0.0.0", port=port)