# api.py — FastAPI sentiment-analysis service (fine-tuned DistilBERT).
# Provenance: Hugging Face upload by AbdoIR, revision fe5868d (verified).
from fastapi import FastAPI, UploadFile, File, HTTPException, Request
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import FileResponse, JSONResponse
from transformers import DistilBertForSequenceClassification, DistilBertTokenizer
import pandas as pd
import torch
import tempfile
import os
import re
from collections import Counter
import datetime
# FastAPI application exposing the sentiment-analysis endpoints defined below.
app = FastAPI()
# Enable CORS
# NOTE(review): wildcard origins together with allow_credentials=True is
# rejected by browsers for credentialed requests — confirm this combination
# is intended, or pin allow_origins to the real frontend origin(s).
app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
# Load model
# Load the fine-tuned DistilBERT checkpoint and tokenizer from the local
# ./fine_tuned_model directory (must exist next to this file at startup).
model = DistilBertForSequenceClassification.from_pretrained("./fine_tuned_model")
tokenizer = DistilBertTokenizer.from_pretrained("./fine_tuned_model")
# Prefer GPU when available; inference-only, so gradients are disabled via eval().
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
model.eval()
# Sentiment prediction
def predict_sentiment(texts):
    """Classify each text in *texts* as "Negative", "Neutral", or "Positive".

    Texts are tokenized in one padded batch (truncated to 128 tokens) and
    scored with the module-level fine-tuned DistilBERT model.
    """
    batch = tokenizer(texts, truncation=True, padding=True, max_length=128, return_tensors="pt")
    batch = {name: tensor.to(device) for name, tensor in batch.items()}
    with torch.no_grad():
        logits = model(**batch).logits
    # Class index -> human-readable label, matching the model's training labels.
    labels = ("Negative", "Neutral", "Positive")
    return [labels[idx] for idx in torch.argmax(logits, dim=1).tolist()]
# Frequent words
def get_top_words(texts, n=30):
    """Return a DataFrame of the *n* most frequent words across *texts*.

    Words are lowercased alphanumeric runs of at least 3 characters; the
    result has columns ['word', 'count'], ordered by descending frequency.
    """
    word_counts = Counter()
    for text in texts:
        word_counts.update(re.findall(r'\b\w{3,}\b', str(text).lower()))
    return pd.DataFrame(word_counts.most_common(n), columns=['word', 'count'])
# Identify column
def get_text_column(df):
    """Return the first recognized text-column name in *df*, or None.

    Candidates are checked in priority order: 'content', 'tweet', 'text'.
    """
    matches = [name for name in ('content', 'tweet', 'text') if name in df.columns]
    return matches[0] if matches else None
@app.api_route("/", methods=["GET", "HEAD"])
async def index(request: Request):
    """Health-check endpoint: report liveness plus a UTC timestamp.

    Fix: datetime.utcnow() is deprecated (Python 3.12+) and returns a naive
    datetime. Use an aware UTC datetime instead, and keep the response's
    original trailing-"Z" ISO-8601 format.
    """
    now = datetime.datetime.now(datetime.timezone.utc)
    return JSONResponse({
        "status": "ok",
        "message": "Server is alive",
        # isoformat() on an aware UTC datetime ends with "+00:00"; normalize
        # to the "Z" suffix this endpoint has always emitted.
        "timestamp": now.isoformat().replace("+00:00", "Z")
    })
# POST /predict
@app.post("/predict")
async def predict(file: UploadFile = File(...)):
    """Run sentiment analysis over an uploaded CSV/Excel file.

    Annotates every row with a sentiment label and content length, writes
    the annotated table and a word-frequency table to a temp directory, and
    returns download links plus both tables inline as JSON records.
    """
    # CSV is attempted first; on failure, rewind the stream and try Excel.
    try:
        frame = pd.read_csv(file.file)
    except Exception:
        try:
            file.file.seek(0)
            frame = pd.read_excel(file.file)
        except Exception:
            raise HTTPException(status_code=400, detail="Unable to read the file")

    column = get_text_column(frame)
    if column is None:
        raise HTTPException(status_code=400, detail='No "content", "tweet", or "text" column found')

    contents = frame[column].astype(str).tolist()
    frame['sentiment'] = predict_sentiment(contents)
    frame['content_length'] = frame[column].astype(str).apply(len)
    word_table = get_top_words(contents)

    # Persist both result tables so GET /download can serve them later.
    out_dir = tempfile.mkdtemp()
    sentiment_path = os.path.join(out_dir, 'final_data.csv')
    words_path = os.path.join(out_dir, 'word_frequent.csv')
    frame.to_csv(sentiment_path, index=False)
    word_table.to_csv(words_path, index=False)

    return JSONResponse({
        'sentiment_file': f'/download?file={sentiment_path}',
        'top_words_file': f'/download?file={words_path}',
        'sentiment_data': frame.to_dict(orient='records'),
        'top_words_data': word_table.to_dict(orient='records')
    })
# POST /wordcloud
@app.post("/wordcloud")
async def wordcloud(file: UploadFile = File(...)):
    """Return the most frequent words from an uploaded CSV/Excel file."""
    # CSV is attempted first; on failure, rewind the stream and try Excel.
    try:
        frame = pd.read_csv(file.file)
    except Exception:
        try:
            file.file.seek(0)
            frame = pd.read_excel(file.file)
        except Exception:
            raise HTTPException(status_code=400, detail="Unable to read the file")

    column = get_text_column(frame)
    if column is None:
        raise HTTPException(status_code=400, detail='No "content", "tweet", or "text" column found')

    rows = frame[column].astype(str).tolist()
    return JSONResponse({'top_words_data': get_top_words(rows).to_dict(orient='records')})
# GET /download
@app.get("/download")
async def download(file: str):
    """Serve a result file previously written by POST /predict.

    Security fix: the original implementation served ANY readable path on
    disk (arbitrary file read, e.g. ?file=/etc/passwd). /predict only ever
    writes under tempfile.mkdtemp(), which lives inside the system temp
    directory — so restrict downloads to real files under that directory
    and resolve symlinks/".." before checking.

    Raises HTTP 404 for missing, empty, or out-of-bounds paths.
    """
    if not file:
        raise HTTPException(status_code=404, detail="File not found")
    resolved = os.path.realpath(file)
    temp_root = os.path.realpath(tempfile.gettempdir())
    # Reject anything that escapes the temp directory or is not a regular file.
    if not resolved.startswith(temp_root + os.sep) or not os.path.isfile(resolved):
        raise HTTPException(status_code=404, detail="File not found")
    return FileResponse(resolved, filename=os.path.basename(resolved))