Spaces:
Running
Running
| from fastapi import FastAPI, UploadFile, File, HTTPException, Request | |
| from fastapi.middleware.cors import CORSMiddleware | |
| from fastapi.responses import FileResponse, JSONResponse | |
| from transformers import DistilBertForSequenceClassification, DistilBertTokenizer | |
| import pandas as pd | |
| import torch | |
| import tempfile | |
| import os | |
| import re | |
| from collections import Counter | |
| import datetime | |
# ASGI application object for this service.
app = FastAPI()
# Enable CORS: any origin, method and header is accepted.
# NOTE(review): the FastAPI/Starlette CORS documentation says wildcard
# origins should not be combined with allow_credentials=True. No handler
# visible in this file reads cookies or auth headers, so credentials look
# unnecessary -- confirm with the front end before tightening this.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
# Load model: weights and tokenizer are read once, at import time, from the
# local "./fine_tuned_model" directory.
model = DistilBertForSequenceClassification.from_pretrained("./fine_tuned_model")
tokenizer = DistilBertTokenizer.from_pretrained("./fine_tuned_model")
# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
# Put the model in evaluation mode; it is only used for inference below.
model.eval()
# Sentiment prediction
def predict_sentiment(texts, batch_size=32):
    """Classify each text as "Negative", "Neutral" or "Positive".

    The inputs are tokenized (truncated to 128 tokens) and scored with the
    module-level ``model``/``tokenizer`` in chunks of ``batch_size`` so that
    memory use stays bounded no matter how many rows were uploaded.

    Args:
        texts: A list of strings (a single string is treated as one text).
        batch_size: Maximum number of texts per forward pass.

    Returns:
        A list of labels, one per input text, in input order. An empty
        input yields an empty list without touching the model.

    Raises:
        ValueError: If ``batch_size`` is smaller than 1.
        KeyError: If the model predicts a class index outside 0..2.
    """
    if batch_size < 1:
        raise ValueError("batch_size must be a positive integer")

    # A bare string is one text, not a sequence of one-character texts.
    if isinstance(texts, str):
        texts = [texts]
    else:
        texts = list(texts)

    sentiment_map = {0: "Negative", 1: "Neutral", 2: "Positive"}
    labels = []
    for start in range(0, len(texts), batch_size):
        batch = texts[start:start + batch_size]
        encodings = tokenizer(
            batch,
            truncation=True,
            padding=True,
            max_length=128,
            return_tensors="pt",
        )
        encodings = {key: val.to(device) for key, val in encodings.items()}
        # Inference only: skip gradient bookkeeping.
        with torch.no_grad():
            outputs = model(**encodings)
        predictions = torch.argmax(outputs.logits, dim=1)
        labels.extend(sentiment_map[idx] for idx in predictions.tolist())
    return labels
# Frequent words
def get_top_words(texts, n=30):
    """Return the ``n`` most frequent words across ``texts``.

    Each item is converted with ``str()`` and lower-cased; a word is any run
    of at least three word characters. The result is a DataFrame with the
    columns ``word`` and ``count``, most frequent first.
    """
    word_pattern = re.compile(r'\b\w{3,}\b')
    tally = Counter()
    for entry in texts:
        tally.update(word_pattern.findall(str(entry).lower()))
    return pd.DataFrame(tally.most_common(n), columns=['word', 'count'])
# Identify column
def get_text_column(df):
    """Pick the column holding the text to analyse.

    The candidates are tried in priority order -- ``content``, ``tweet``,
    ``text`` -- and the first one present in ``df.columns`` is returned.
    ``None`` is returned when the frame has none of them.
    """
    preferred = ('content', 'tweet', 'text')
    return next((name for name in preferred if name in df.columns), None)
# NOTE(review): no route decorator is visible for this handler in this file,
# so it may not be registered on `app`; confirm the intended path
# (presumably "/") and register it.
async def index(request: Request):
    """Liveness probe.

    Args:
        request: The incoming request; it is accepted but not inspected.

    Returns:
        A JSON response with ``status``, ``message`` and the current UTC
        time as ``timestamp`` (ISO 8601 with a trailing ``Z``).
    """
    # datetime.utcnow() is deprecated; take an aware "now" and drop the
    # tzinfo so the rendered string keeps the same "<iso>Z" shape.
    now_utc = datetime.datetime.now(datetime.timezone.utc)
    timestamp = now_utc.replace(tzinfo=None).isoformat() + "Z"
    return JSONResponse({
        "status": "ok",
        "message": "Server is alive",
        "timestamp": timestamp,
    })
# POST /predict
@app.post("/predict")
async def predict(file: UploadFile = File(...)):
    """Run sentiment analysis and word counting on an uploaded table.

    The upload is parsed as CSV first and, if that fails, as an Excel
    workbook. The text column is chosen by ``get_text_column``. Two columns
    are added to the table: ``sentiment`` (from ``predict_sentiment``) and
    ``content_length`` (number of characters of the text). The annotated
    table and the word-frequency table are written as CSV files into a new
    temporary directory.

    Args:
        file: The uploaded CSV or Excel file.

    Returns:
        A JSON response with download links for both CSV files plus the
        same data inline as lists of records.

    Raises:
        HTTPException: 400 if the upload cannot be parsed or has no
            "content", "tweet" or "text" column.
    """
    try:
        df = pd.read_csv(file.file)
    except Exception:
        # Not parseable as CSV: rewind the stream and retry as Excel.
        try:
            file.file.seek(0)
            df = pd.read_excel(file.file)
        except Exception:
            raise HTTPException(status_code=400, detail="Unable to read the file") from None

    text_col = get_text_column(df)
    if not text_col:
        raise HTTPException(status_code=400, detail='No "content", "tweet", or "text" column found')

    # Missing cells become empty strings; astype(str) alone would turn them
    # into the literal text "nan", which would then be classified and counted.
    text_series = df[text_col].fillna("").astype(str)
    texts = text_series.tolist()

    # Do not hand an empty batch to the model when the table has no rows.
    df['sentiment'] = predict_sentiment(texts) if texts else []
    df['content_length'] = text_series.str.len()
    top_words_df = get_top_words(texts)

    # NOTE(review): this directory is never removed, so every request leaves
    # files behind; consider a cleanup strategy once downloads are tracked.
    temp_dir = tempfile.mkdtemp()
    sentiment_path = os.path.join(temp_dir, 'final_data.csv')
    words_path = os.path.join(temp_dir, 'word_frequent.csv')
    df.to_csv(sentiment_path, index=False)
    top_words_df.to_csv(words_path, index=False)

    # NaN is not valid JSON and makes JSONResponse fail, so missing values
    # are sent as null instead.
    sentiment_records = df.astype(object).where(df.notna(), None).to_dict(orient='records')

    return JSONResponse({
        'sentiment_file': f'/download?file={sentiment_path}',
        'top_words_file': f'/download?file={words_path}',
        'sentiment_data': sentiment_records,
        'top_words_data': top_words_df.to_dict(orient='records')
    })
# POST /wordcloud
@app.post("/wordcloud")
async def wordcloud(file: UploadFile = File(...)):
    """Return the most frequent words of an uploaded table's text column.

    The upload is parsed as CSV first and, if that fails, as an Excel
    workbook. The text column is chosen by ``get_text_column`` and counted
    by ``get_top_words``.

    Args:
        file: The uploaded CSV or Excel file.

    Returns:
        A JSON response whose ``top_words_data`` is a list of
        ``{"word", "count"}`` records.

    Raises:
        HTTPException: 400 if the upload cannot be parsed or has no
            "content", "tweet" or "text" column.
    """
    try:
        df = pd.read_csv(file.file)
    except Exception:
        # Not parseable as CSV: rewind the stream and retry as Excel.
        try:
            file.file.seek(0)
            df = pd.read_excel(file.file)
        except Exception:
            raise HTTPException(status_code=400, detail="Unable to read the file") from None

    text_col = get_text_column(df)
    if not text_col:
        raise HTTPException(status_code=400, detail='No "content", "tweet", or "text" column found')

    # Missing cells become empty strings; astype(str) alone would turn them
    # into the literal text "nan", which would then be counted as a word.
    texts = df[text_col].fillna("").astype(str).tolist()
    top_words_df = get_top_words(texts)
    return JSONResponse({'top_words_data': top_words_df.to_dict(orient='records')})
# GET /download
@app.get("/download")
async def download(file: str):
    """Serve one of the CSV exports produced by ``predict``.

    ``file`` comes straight from the query string, so it is untrusted.
    Only the two export file names that ``predict`` writes, located under
    the system temporary directory, are served; every other path is
    answered with 404 so the endpoint cannot be used to read arbitrary
    files from the server.

    Args:
        file: Path of the export, as given in the links returned by
            ``predict``.

    Returns:
        A file response with the export as an attachment named after the
        file.

    Raises:
        HTTPException: 404 if the path is empty, outside the temporary
            directory, not one of the export names, or not an existing file.
    """
    allowed_names = {'final_data.csv', 'word_frequent.csv'}

    if not file:
        raise HTTPException(status_code=404, detail="File not found")

    try:
        # Resolve symlinks and ".." segments before checking the location.
        resolved = os.path.realpath(file)
    except (ValueError, OSError):
        # e.g. an embedded NUL byte in the supplied path.
        raise HTTPException(status_code=404, detail="File not found") from None

    temp_root = os.path.realpath(tempfile.gettempdir())
    inside_temp = resolved.startswith(temp_root + os.sep)
    is_export = os.path.basename(resolved) in allowed_names

    if not (inside_temp and is_export and os.path.isfile(resolved)):
        raise HTTPException(status_code=404, detail="File not found")

    return FileResponse(resolved, filename=os.path.basename(resolved))