Tomiwajin committed on
Commit
df30b8a
·
verified ·
1 Parent(s): fb89cb7

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +48 -62
app.py CHANGED
@@ -1,4 +1,3 @@
1
- # app.py - HuggingFace Space for Email Classification
2
  import gradio as gr
3
  import torch
4
  from transformers import pipeline, AutoTokenizer, AutoModelForSequenceClassification
@@ -6,6 +5,7 @@ from setfit import SetFitModel
6
  import json
7
  import logging
8
  from typing import List, Dict, Any
 
9
 
10
  # Set up logging
11
  logging.basicConfig(level=logging.INFO)
@@ -20,14 +20,18 @@ def load_model():
20
  """Load your trained SetFit model"""
21
  global model, classifier
22
  try:
23
- # Replace with your actual model path/name
24
  model_name = "Tomiwajin/setfit_email_classifier"
25
 
26
- # For SetFit models
27
- model = SetFitModel.from_pretrained(model_name)
28
- classifier = pipeline("text-classification", model=model.model_head, tokenizer=model.model_body.tokenizer)
29
 
 
 
 
 
 
 
30
 
 
31
  logger.info(f"Model {model_name} loaded successfully!")
32
  return True
33
  except Exception as e:
@@ -36,22 +40,24 @@ def load_model():
36
 
37
  def classify_single_email(email_text: str) -> Dict[str, Any]:
38
  """Classify a single email"""
39
- if not classifier:
40
  return {"error": "Model not loaded"}
41
 
42
  try:
43
  # Clean and truncate text
44
  email_text = email_text.strip()[:5000] # Limit length
45
 
46
- # Get prediction
47
- result = classifier(email_text)
 
48
 
49
- if isinstance(result, list):
50
- result = result[0]
 
51
 
52
  return {
53
- "label": result.get("label", "unknown"),
54
- "score": round(result.get("score", 0.0), 4),
55
  "success": True
56
  }
57
  except Exception as e:
@@ -60,15 +66,29 @@ def classify_single_email(email_text: str) -> Dict[str, Any]:
60
 
61
  def classify_batch_emails(emails: List[str]) -> List[Dict[str, Any]]:
62
  """Classify multiple emails"""
63
- if not classifier:
64
  return [{"error": "Model not loaded"}] * len(emails)
65
 
66
- results = []
67
- for email_text in emails:
68
- result = classify_single_email(email_text)
69
- results.append(result)
70
-
71
- return results
 
 
 
 
 
 
 
 
 
 
 
 
 
 
72
 
73
  def gradio_classify(email_text: str) -> str:
74
  """Gradio interface function"""
@@ -97,8 +117,11 @@ def api_classify_batch(emails_json: str) -> str:
97
  if not isinstance(emails, list):
98
  return json.dumps({"error": "Input must be a JSON array of strings"})
99
 
 
 
 
100
  results = classify_batch_emails(emails)
101
- return json.dumps(results, indent=2)
102
  except json.JSONDecodeError:
103
  return json.dumps({"error": "Invalid JSON format"})
104
  except Exception as e:
@@ -151,15 +174,9 @@ with gr.Blocks(title="Email Classifier", theme=gr.themes.Soft()) as demo:
151
  ```
152
 
153
  ### Batch Email Classification
154
- **POST** `/api/classify-batch`
155
  ```json
156
- {
157
- "emails": [
158
- "Email 1 content...",
159
- "Email 2 content...",
160
- "Email 3 content..."
161
- ]
162
- }
163
  ```
164
 
165
  ### Example Response
@@ -220,48 +237,17 @@ with gr.Blocks(title="Email Classifier", theme=gr.themes.Soft()) as demo:
220
  const result = await response.json();
221
 
222
  // Batch classification
223
- const batchResponse = await fetch('https://your-space-name.hf.space/api/classify-batch', {{
224
  method: 'POST',
225
  headers: {{ 'Content-Type': 'application/json' }},
226
- body: JSON.stringify({{ emails: emailArray }})
227
  }});
228
  const batchResults = await batchResponse.json();
229
  ```
230
  """)
231
 
232
- # Set up API endpoints
233
- def setup_api_routes(app):
234
- """Setup FastAPI routes for the Gradio app"""
235
- from fastapi import FastAPI, HTTPException
236
- from pydantic import BaseModel
237
-
238
- class EmailRequest(BaseModel):
239
- email_text: str
240
-
241
- class BatchEmailRequest(BaseModel):
242
- emails: List[str]
243
-
244
- @app.post("/api/classify")
245
- async def classify_endpoint(request: EmailRequest):
246
- result = classify_single_email(request.email_text)
247
- if not result.get("success", True):
248
- raise HTTPException(status_code=500, detail=result.get("error", "Classification failed"))
249
- return result
250
-
251
- @app.post("/api/classify-batch")
252
- async def classify_batch_endpoint(request: BatchEmailRequest):
253
- if len(request.emails) > 100: # Limit batch size
254
- raise HTTPException(status_code=400, detail="Maximum 100 emails per batch")
255
-
256
- results = classify_batch_emails(request.emails)
257
- return {"results": results}
258
-
259
- # Launch the app
260
  if __name__ == "__main__":
261
- # Setup API routes
262
- setup_api_routes(demo.fastapi_app)
263
-
264
- # Launch with API support
265
  demo.launch(
266
  server_name="0.0.0.0",
267
  server_port=7860,
 
 
1
  import gradio as gr
2
  import torch
3
  from transformers import pipeline, AutoTokenizer, AutoModelForSequenceClassification
 
5
  import json
6
  import logging
7
  from typing import List, Dict, Any
8
+ import os
9
 
10
  # Set up logging
11
  logging.basicConfig(level=logging.INFO)
 
20
  """Load your trained SetFit model"""
21
  global model, classifier
22
  try:
23
+
24
  model_name = "Tomiwajin/setfit_email_classifier"
25
 
 
 
 
26
 
27
+ token = os.getenv("HF_TOKEN")
28
+
29
+ model = SetFitModel.from_pretrained(
30
+ model_name,
31
+ use_auth_token=token if token else True
32
+ )
33
 
34
+ # Create classifier directly from SetFit model
35
  logger.info(f"Model {model_name} loaded successfully!")
36
  return True
37
  except Exception as e:
 
40
 
41
  def classify_single_email(email_text: str) -> Dict[str, Any]:
42
  """Classify a single email"""
43
+ if not model:
44
  return {"error": "Model not loaded"}
45
 
46
  try:
47
  # Clean and truncate text
48
  email_text = email_text.strip()[:5000] # Limit length
49
 
50
+ # Get prediction using SetFit model directly
51
+ predictions = model.predict([email_text])
52
+ probabilities = model.predict_proba([email_text])[0] # Get probabilities for first (and only) sample
53
 
54
+ # Get the predicted label and confidence
55
+ predicted_label = predictions[0]
56
+ confidence = max(probabilities) # Confidence is the max probability
57
 
58
  return {
59
+ "label": str(predicted_label),
60
+ "score": round(float(confidence), 4),
61
  "success": True
62
  }
63
  except Exception as e:
 
66
 
67
def classify_batch_emails(emails: List[str]) -> List[Dict[str, Any]]:
    """Classify multiple emails in one batched SetFit call.

    Args:
        emails: Raw email bodies. Each is stripped and truncated to 5000
            characters, mirroring the single-email path.

    Returns:
        One dict per input email, in order: ``{"label", "score", "success"}``
        on success, or ``{"error": ...}`` entries when the global model is
        unavailable or prediction fails.
    """
    # Identity check is safer than truthiness on an arbitrary model object
    # (some model wrappers define __bool__/__len__ oddly or raise).
    if model is None:
        return [{"error": "Model not loaded"}] * len(emails)

    try:
        # Normalize inputs the same way classify_single_email does.
        cleaned_emails = [email.strip()[:5000] for email in emails]

        # One batched call each for predicted labels and class probabilities.
        predictions = model.predict(cleaned_emails)
        probabilities = model.predict_proba(cleaned_emails)

        # Confidence is the maximum class probability for each sample.
        # (Comprehension instead of the original append loop with an
        # unused enumerate index.)
        return [
            {
                "label": str(pred),
                "score": round(float(max(probs)), 4),
                "success": True,
            }
            for pred, probs in zip(predictions, probabilities)
        ]
    except Exception as e:
        logger.error(f"Batch classification error: {e}")
        return [{"error": str(e), "success": False}] * len(emails)
92
 
93
  def gradio_classify(email_text: str) -> str:
94
  """Gradio interface function"""
 
117
  if not isinstance(emails, list):
118
  return json.dumps({"error": "Input must be a JSON array of strings"})
119
 
120
+ if len(emails) > 100: # Limit batch size
121
+ return json.dumps({"error": "Maximum 100 emails per batch"})
122
+
123
  results = classify_batch_emails(emails)
124
+ return json.dumps({"results": results}, indent=2)
125
  except json.JSONDecodeError:
126
  return json.dumps({"error": "Invalid JSON format"})
127
  except Exception as e:
 
174
  ```
175
 
176
  ### Batch Email Classification
177
+ **POST** `/api/classify_batch`
178
  ```json
179
+ ["Email 1 content...", "Email 2 content...", "Email 3 content..."]
 
 
 
 
 
 
180
  ```
181
 
182
  ### Example Response
 
237
  const result = await response.json();
238
 
239
  // Batch classification
240
+ const batchResponse = await fetch('https://your-space-name.hf.space/api/classify_batch', {{
241
  method: 'POST',
242
  headers: {{ 'Content-Type': 'application/json' }},
243
+ body: JSON.stringify(emailArray)
244
  }});
245
  const batchResults = await batchResponse.json();
246
  ```
247
  """)
248
 
249
+ # Launch the app with API endpoints
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
250
  if __name__ == "__main__":
 
 
 
 
251
  demo.launch(
252
  server_name="0.0.0.0",
253
  server_port=7860,