Meshyboi committed on
Commit
361e35b
·
verified ·
1 Parent(s): 3118310

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +26 -1
app.py CHANGED
@@ -5,6 +5,7 @@ import os
5
  import keras
6
  import numpy as np
7
  from transformers import RobertaTokenizer
 
8
  import uvicorn
9
 
10
  # Set Keras backend
@@ -23,6 +24,7 @@ MAX_LEN = 61
23
 
24
  # Hugging Face model repository
25
  HF_MODEL_ID = "Meshyboi/Multi-Emotion-Classification"
 
26
 
27
  # Global variables for model and tokenizer
28
  model = None
@@ -42,15 +44,26 @@ class PredictionResponse(BaseModel):
42
  detected_emotions: List[str]
43
 
44
  def load_model():
 
45
  global model
46
  try:
47
  if model is None:
48
- model = keras.saving.load_model(f"hf://{HF_MODEL_ID}")
 
 
 
 
 
 
 
 
 
49
  return model
50
  except Exception as e:
51
  raise RuntimeError(f"Error loading model: {str(e)}")
52
 
53
  def load_tokenizer():
 
54
  global tokenizer
55
  try:
56
  if tokenizer is None:
@@ -60,6 +73,7 @@ def load_tokenizer():
60
  raise RuntimeError(f"Error loading tokenizer: {str(e)}")
61
 
62
  def preprocess_text(text: str, tokenizer, max_len: int):
 
63
  encoded = tokenizer.encode_plus(
64
  text,
65
  add_special_tokens=True,
@@ -72,6 +86,7 @@ def preprocess_text(text: str, tokenizer, max_len: int):
72
  return encoded['input_ids'], encoded['attention_mask']
73
 
74
  def predict_emotions(text: str, model, tokenizer):
 
75
  input_ids, attention_mask = preprocess_text(text, tokenizer, MAX_LEN)
76
  predictions = model.predict([input_ids, attention_mask], verbose=0)
77
  return predictions[0]
@@ -86,6 +101,7 @@ async def startup_event():
86
 
87
  @app.get("/")
88
  async def root():
 
89
  return {
90
  "message": "Emotion Classification API",
91
  "version": "1.0.0",
@@ -98,6 +114,7 @@ async def root():
98
 
99
  @app.get("/health")
100
  async def health_check():
 
101
  return {
102
  "status": "healthy",
103
  "model_loaded": model is not None,
@@ -106,6 +123,14 @@ async def health_check():
106
 
107
  @app.post("/predict", response_model=PredictionResponse)
108
  async def predict(request: PredictionRequest):
 
 
 
 
 
 
 
 
109
  if not request.text.strip():
110
  raise HTTPException(status_code=400, detail="Text cannot be empty")
111
 
 
5
  import keras
6
  import numpy as np
7
  from transformers import RobertaTokenizer
8
+ from huggingface_hub import hf_hub_download
9
  import uvicorn
10
 
11
  # Set Keras backend
 
24
 
25
  # Hugging Face model repository
26
  HF_MODEL_ID = "Meshyboi/Multi-Emotion-Classification"
27
+ MODEL_FILENAME = "roberta_emotion_model.keras"
28
 
29
  # Global variables for model and tokenizer
30
  model = None
 
44
  detected_emotions: List[str]
45
 
46
def load_model():
    """Load the trained Keras emotion-classification model from Hugging Face.

    Downloads MODEL_FILENAME from the HF_MODEL_ID repository on first call
    and caches the loaded model in the module-level ``model`` global, so
    subsequent calls are no-ops that return the cached instance.

    Returns:
        The loaded Keras model.

    Raises:
        RuntimeError: if the download or deserialization fails; the original
            exception is attached as ``__cause__`` via ``raise ... from e``.
    """
    global model
    try:
        if model is None:
            # First call: fetch the serialized model file into the default
            # huggingface_hub cache, then deserialize it with Keras.
            print(f"Downloading model file: {MODEL_FILENAME}")
            model_path = hf_hub_download(
                repo_id=HF_MODEL_ID,
                filename=MODEL_FILENAME,
            )
            print(f"Model downloaded to: {model_path}")
            # Load the model from the downloaded file
            model = keras.saving.load_model(model_path)
        return model
    except Exception as e:
        # Chain the original exception (`from e`) so the real traceback —
        # network error, corrupt file, version mismatch — is preserved
        # instead of being flattened into the message string alone.
        raise RuntimeError(f"Error loading model: {str(e)}") from e
64
 
65
  def load_tokenizer():
66
+ """Load the tokenizer from Hugging Face"""
67
  global tokenizer
68
  try:
69
  if tokenizer is None:
 
73
  raise RuntimeError(f"Error loading tokenizer: {str(e)}")
74
 
75
  def preprocess_text(text: str, tokenizer, max_len: int):
76
+ """Preprocess text for model input"""
77
  encoded = tokenizer.encode_plus(
78
  text,
79
  add_special_tokens=True,
 
86
  return encoded['input_ids'], encoded['attention_mask']
87
 
88
def predict_emotions(text: str, model, tokenizer):
    """Predict emotions for given text.

    Tokenizes *text* to model inputs, runs a single silent inference
    (``verbose=0``), and returns the score vector for the one sample.
    """
    token_ids, mask = preprocess_text(text, tokenizer, MAX_LEN)
    batch_scores = model.predict([token_ids, mask], verbose=0)
    # Batch of one — unwrap to the per-emotion score vector.
    return batch_scores[0]
 
101
 
102
  @app.get("/")
103
  async def root():
104
+ """Root endpoint"""
105
  return {
106
  "message": "Emotion Classification API",
107
  "version": "1.0.0",
 
114
 
115
  @app.get("/health")
116
  async def health_check():
117
+ """Health check endpoint"""
118
  return {
119
  "status": "healthy",
120
  "model_loaded": model is not None,
 
123
 
124
  @app.post("/predict", response_model=PredictionResponse)
125
  async def predict(request: PredictionRequest):
126
+ """
127
+ Predict emotions for the given text
128
+
129
+ - **text**: Input text to analyze for emotions
130
+
131
+ Returns:
132
+ - Dictionary with emotion scores and detected emotions
133
+ """
134
  if not request.text.strip():
135
  raise HTTPException(status_code=400, detail="Text cannot be empty")
136