AbdoIR commited on
Commit
280ba73
·
verified ·
1 Parent(s): 0315d54

Update api.py

Browse files
Files changed (1) hide show
  1. api.py +45 -40
api.py CHANGED
@@ -1,4 +1,6 @@
1
- from flask import Flask, request, send_file, jsonify
 
 
2
  from transformers import DistilBertForSequenceClassification, DistilBertTokenizer
3
  import pandas as pd
4
  import torch
@@ -6,10 +8,17 @@ import tempfile
6
  import os
7
  import re
8
  from collections import Counter
9
- from flask_cors import CORS
10
 
11
- app = Flask(__name__)
12
- CORS(app)
 
 
 
 
 
 
 
 
13
 
14
  # Load model
15
  model = DistilBertForSequenceClassification.from_pretrained("./fine_tuned_model")
@@ -19,6 +28,7 @@ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
19
  model.to(device)
20
  model.eval()
21
 
 
22
  # Sentiment prediction
23
  def predict_sentiment(texts):
24
  encodings = tokenizer(texts, truncation=True, padding=True, max_length=128, return_tensors="pt")
@@ -29,6 +39,7 @@ def predict_sentiment(texts):
29
  sentiment_map = {0: "Negative", 1: "Neutral", 2: "Positive"}
30
  return [sentiment_map[p.item()] for p in predictions]
31
 
 
32
  # Frequent words
33
  def get_top_words(texts, n=30):
34
  all_words = []
@@ -39,6 +50,7 @@ def get_top_words(texts, n=30):
39
  most_common = counter.most_common(n)
40
  return pd.DataFrame(most_common, columns=['word', 'count'])
41
 
 
42
  # Identify column
43
  def get_text_column(df):
44
  for col in ['content', 'tweet', 'text']:
@@ -46,28 +58,24 @@ def get_text_column(df):
46
  return col
47
  return None
48
 
49
- # POST /predict
50
- @app.route('/predict', methods=['POST'])
51
- def predict():
52
- if 'file' not in request.files:
53
- return jsonify({'error': 'No file uploaded'}), 400
54
 
55
- file = request.files['file']
 
 
56
  try:
57
- df = pd.read_csv(file)
58
  except Exception:
59
  try:
60
- file.seek(0)
61
- df = pd.read_excel(file)
62
  except Exception:
63
- return jsonify({'error': 'Unable to read the file'}), 400
64
 
65
  text_col = get_text_column(df)
66
  if not text_col:
67
- return jsonify({'error': 'No "content", "tweet", or "text" column found'}), 400
68
 
69
  texts = df[text_col].astype(str).tolist()
70
-
71
  df['sentiment'] = predict_sentiment(texts)
72
  df['content_length'] = df[text_col].astype(str).apply(len)
73
 
@@ -75,50 +83,47 @@ def predict():
75
 
76
  temp_dir = tempfile.mkdtemp()
77
  sentiment_path = os.path.join(temp_dir, 'final_data.csv')
78
- df.to_csv(sentiment_path, index=False)
79
-
80
  words_path = os.path.join(temp_dir, 'word_frequent.csv')
 
 
81
  top_words_df.to_csv(words_path, index=False)
82
 
83
- return jsonify({
84
  'sentiment_file': f'/download?file={sentiment_path}',
85
  'top_words_file': f'/download?file={words_path}',
86
  'sentiment_data': df.to_dict(orient='records'),
87
  'top_words_data': top_words_df.to_dict(orient='records')
88
  })
89
 
90
- # POST /wordcloud
91
- @app.route('/wordcloud', methods=['POST'])
92
- def wordcloud():
93
- if 'file' not in request.files:
94
- return jsonify({'error': 'No file uploaded'}), 400
95
 
96
- file = request.files['file']
 
 
97
  try:
98
- df = pd.read_csv(file)
99
  except Exception:
100
  try:
101
- file.seek(0)
102
- df = pd.read_excel(file)
103
  except Exception:
104
- return jsonify({'error': 'Unable to read the file'}), 400
105
 
106
  text_col = get_text_column(df)
107
  if not text_col:
108
- return jsonify({'error': 'No "content", "tweet", or "text" column found'}), 400
109
 
110
  texts = df[text_col].astype(str).tolist()
111
  top_words_df = get_top_words(texts)
112
 
113
- return jsonify({'top_words_data': top_words_df.to_dict(orient='records')})
 
114
 
115
  # GET /download
116
- @app.route('/download')
117
- def download():
118
- file_path = request.args.get('file')
119
- if not file_path or not os.path.exists(file_path):
120
- return jsonify({'error': 'File not found'}), 404
121
- return send_file(file_path, as_attachment=True)
122
-
123
- if __name__ == '__main__':
124
- app.run(host="0.0.0.0", port=7860, debug=True)
 
1
+ from fastapi import FastAPI, UploadFile, File, HTTPException
2
+ from fastapi.middleware.cors import CORSMiddleware
3
+ from fastapi.responses import FileResponse, JSONResponse
4
  from transformers import DistilBertForSequenceClassification, DistilBertTokenizer
5
  import pandas as pd
6
  import torch
 
8
  import os
9
  import re
10
  from collections import Counter
 
11
 
12
+ app = FastAPI()
13
+
14
+ # Enable CORS
15
+ app.add_middleware(
16
+ CORSMiddleware,
17
+ allow_origins=["*"],
18
+ allow_credentials=True,
19
+ allow_methods=["*"],
20
+ allow_headers=["*"],
21
+ )
22
 
23
  # Load model
24
  model = DistilBertForSequenceClassification.from_pretrained("./fine_tuned_model")
 
28
  model.to(device)
29
  model.eval()
30
 
31
+
32
  # Sentiment prediction
33
  def predict_sentiment(texts):
34
  encodings = tokenizer(texts, truncation=True, padding=True, max_length=128, return_tensors="pt")
 
39
  sentiment_map = {0: "Negative", 1: "Neutral", 2: "Positive"}
40
  return [sentiment_map[p.item()] for p in predictions]
41
 
42
+
43
  # Frequent words
44
  def get_top_words(texts, n=30):
45
  all_words = []
 
50
  most_common = counter.most_common(n)
51
  return pd.DataFrame(most_common, columns=['word', 'count'])
52
 
53
+
54
  # Identify column
55
  def get_text_column(df):
56
  for col in ['content', 'tweet', 'text']:
 
58
  return col
59
  return None
60
 
 
 
 
 
 
61
 
62
+ # POST /predict
63
+ @app.post("/predict")
64
+ async def predict(file: UploadFile = File(...)):
65
  try:
66
+ df = pd.read_csv(file.file)
67
  except Exception:
68
  try:
69
+ file.file.seek(0)
70
+ df = pd.read_excel(file.file)
71
  except Exception:
72
+ raise HTTPException(status_code=400, detail="Unable to read the file")
73
 
74
  text_col = get_text_column(df)
75
  if not text_col:
76
+ raise HTTPException(status_code=400, detail='No "content", "tweet", or "text" column found')
77
 
78
  texts = df[text_col].astype(str).tolist()
 
79
  df['sentiment'] = predict_sentiment(texts)
80
  df['content_length'] = df[text_col].astype(str).apply(len)
81
 
 
83
 
84
  temp_dir = tempfile.mkdtemp()
85
  sentiment_path = os.path.join(temp_dir, 'final_data.csv')
 
 
86
  words_path = os.path.join(temp_dir, 'word_frequent.csv')
87
+
88
+ df.to_csv(sentiment_path, index=False)
89
  top_words_df.to_csv(words_path, index=False)
90
 
91
+ return JSONResponse({
92
  'sentiment_file': f'/download?file={sentiment_path}',
93
  'top_words_file': f'/download?file={words_path}',
94
  'sentiment_data': df.to_dict(orient='records'),
95
  'top_words_data': top_words_df.to_dict(orient='records')
96
  })
97
 
 
 
 
 
 
98
 
99
+ # POST /wordcloud
100
+ @app.post("/wordcloud")
101
+ async def wordcloud(file: UploadFile = File(...)):
102
  try:
103
+ df = pd.read_csv(file.file)
104
  except Exception:
105
  try:
106
+ file.file.seek(0)
107
+ df = pd.read_excel(file.file)
108
  except Exception:
109
+ raise HTTPException(status_code=400, detail="Unable to read the file")
110
 
111
  text_col = get_text_column(df)
112
  if not text_col:
113
+ raise HTTPException(status_code=400, detail='No "content", "tweet", or "text" column found')
114
 
115
  texts = df[text_col].astype(str).tolist()
116
  top_words_df = get_top_words(texts)
117
 
118
+ return JSONResponse({'top_words_data': top_words_df.to_dict(orient='records')})
119
+
120
 
121
  # GET /download
122
+ @app.get("/download")
123
+ async def download(file: str):
124
+ if not file or not os.path.exists(file):
125
+ raise HTTPException(status_code=404, detail="File not found")
126
+ return FileResponse(file, filename=os.path.basename(file))
127
+
128
+
129
+ # Run with: uvicorn main:app --host 0.0.0.0 --port 7860