Spaces:
Running
Running
File size: 4,186 Bytes
fe5868d 280ba73 297a341 89a9563 a8a907b 280ba73 297a341 df5115f a12b972 297a341 280ba73 df5115f 297a341 df5115f 280ba73 df5115f 297a341 df5115f 0315d54 df5115f 280ba73 df5115f 297a341 4672373 df5115f 280ba73 df5115f 280ba73 df5115f 280ba73 df5115f 280ba73 df5115f 280ba73 df5115f a516d6e df5115f 280ba73 df5115f 280ba73 df5115f 280ba73 df5115f 280ba73 df5115f 280ba73 df5115f 280ba73 df5115f 280ba73 df5115f 280ba73 297a341 280ba73 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 |
from fastapi import FastAPI, UploadFile, File, HTTPException, Request
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import FileResponse, JSONResponse
from transformers import DistilBertForSequenceClassification, DistilBertTokenizer
import pandas as pd
import torch
import tempfile
import os
import re
from collections import Counter
import datetime
# FastAPI application instance; all routes below are registered on it.
app = FastAPI()
# Enable CORS
# NOTE(review): wildcard origins combined with allow_credentials=True is very
# permissive (browsers drop credentials when the origin is "*") — confirm this
# open policy is intended for production.
app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
# Load model
# Fine-tuned DistilBERT sequence classifier + matching tokenizer, loaded once
# at import time from the local "./fine_tuned_model" directory.
model = DistilBertForSequenceClassification.from_pretrained("./fine_tuned_model")
tokenizer = DistilBertTokenizer.from_pretrained("./fine_tuned_model")
# Prefer GPU when available; all inference tensors are moved to this device.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
# Inference only: disables dropout / switches batch-norm style layers to eval.
model.eval()
# Sentiment prediction
def predict_sentiment(texts):
encodings = tokenizer(texts, truncation=True, padding=True, max_length=128, return_tensors="pt")
encodings = {key: val.to(device) for key, val in encodings.items()}
with torch.no_grad():
outputs = model(**encodings)
predictions = torch.argmax(outputs.logits, dim=1)
sentiment_map = {0: "Negative", 1: "Neutral", 2: "Positive"}
return [sentiment_map[p.item()] for p in predictions]
# Frequent words
def get_top_words(texts, n=30):
all_words = []
for text in texts:
tokens = re.findall(r'\b\w{3,}\b', str(text).lower())
all_words.extend(tokens)
counter = Counter(all_words)
most_common = counter.most_common(n)
return pd.DataFrame(most_common, columns=['word', 'count'])
# Identify column
def get_text_column(df):
for col in ['content', 'tweet', 'text']:
if col in df.columns:
return col
return None
@app.api_route("/", methods=["GET", "HEAD"])
async def index(request: Request):
    """Health-check endpoint: reports the server is alive with a UTC timestamp.

    Fix: datetime.datetime.utcnow() is deprecated since Python 3.12; use a
    timezone-aware now(UTC) instead. The tzinfo is dropped before isoformat()
    so the payload keeps the exact same "....Z" shape as before (no "+00:00").
    """
    now_utc = datetime.datetime.now(datetime.timezone.utc)
    return JSONResponse({
        "status": "ok",
        "message": "Server is alive",
        "timestamp": now_utc.replace(tzinfo=None).isoformat() + "Z"
    })
# POST /predict
@app.post("/predict")
async def predict(file: UploadFile = File(...)):
    """Run sentiment analysis over an uploaded CSV or Excel file.

    Reads the upload (CSV first, then Excel as fallback), classifies the
    text column, writes two result CSVs into a fresh temp directory and
    returns download links plus the data inline.

    Raises:
        HTTPException 400: unreadable file, or no recognised text column.
    """
    # Attempt CSV first; on failure rewind the stream and try Excel.
    frame = None
    try:
        frame = pd.read_csv(file.file)
    except Exception:
        pass
    if frame is None:
        file.file.seek(0)
        try:
            frame = pd.read_excel(file.file)
        except Exception:
            raise HTTPException(status_code=400, detail="Unable to read the file")
    text_col = get_text_column(frame)
    if text_col is None:
        raise HTTPException(status_code=400, detail='No "content", "tweet", or "text" column found')
    as_strings = frame[text_col].astype(str)
    texts = as_strings.tolist()
    # Column order matters for the records payload: sentiment, then length.
    frame['sentiment'] = predict_sentiment(texts)
    frame['content_length'] = as_strings.apply(len)
    top_words_df = get_top_words(texts)
    # Persist both result sets so they can be fetched via GET /download.
    out_dir = tempfile.mkdtemp()
    sentiment_path = os.path.join(out_dir, 'final_data.csv')
    words_path = os.path.join(out_dir, 'word_frequent.csv')
    frame.to_csv(sentiment_path, index=False)
    top_words_df.to_csv(words_path, index=False)
    return JSONResponse({
        'sentiment_file': f'/download?file={sentiment_path}',
        'top_words_file': f'/download?file={words_path}',
        'sentiment_data': frame.to_dict(orient='records'),
        'top_words_data': top_words_df.to_dict(orient='records')
    })
# POST /wordcloud
@app.post("/wordcloud")
async def wordcloud(file: UploadFile = File(...)):
    """Return the most frequent words of an uploaded CSV/Excel file.

    Same file handling as /predict, but only word counts are computed;
    nothing is written to disk.

    Raises:
        HTTPException 400: unreadable file, or no recognised text column.
    """
    # Attempt CSV first; on failure rewind the stream and try Excel.
    frame = None
    try:
        frame = pd.read_csv(file.file)
    except Exception:
        pass
    if frame is None:
        file.file.seek(0)
        try:
            frame = pd.read_excel(file.file)
        except Exception:
            raise HTTPException(status_code=400, detail="Unable to read the file")
    text_col = get_text_column(frame)
    if text_col is None:
        raise HTTPException(status_code=400, detail='No "content", "tweet", or "text" column found')
    top_words_df = get_top_words(frame[text_col].astype(str).tolist())
    return JSONResponse({'top_words_data': top_words_df.to_dict(orient='records')})
# GET /download
@app.get("/download")
async def download(file: str):
    """Serve a result file previously generated by /predict.

    Security fix: the original served ANY path the client supplied
    (e.g. ?file=/etc/passwd), an arbitrary-file-disclosure hole. The
    requested path is now resolved (realpath, defeating ../ and symlink
    tricks) and must live inside the system temp directory — which is
    where /predict writes its outputs via tempfile.mkdtemp() — so all
    legitimate links keep working.

    Raises:
        HTTPException 404: empty/missing file, or path outside temp dir.
    """
    if not file:
        raise HTTPException(status_code=404, detail="File not found")
    resolved = os.path.realpath(file)
    temp_root = os.path.realpath(tempfile.gettempdir())
    # Reject anything that escapes the temp root (path traversal guard).
    if not resolved.startswith(temp_root + os.sep):
        raise HTTPException(status_code=404, detail="File not found")
    if not os.path.exists(resolved):
        raise HTTPException(status_code=404, detail="File not found")
    return FileResponse(resolved, filename=os.path.basename(resolved))
|