gaidasalsaa's picture
Fix: Corrected absolute imports in main.py for Docker
46465cb
from fastapi import FastAPI
from pydantic import BaseModel
from typing import Dict, Optional, List
import numpy as np # Untuk operasi np.argmax
import os # Untuk os.environ (Bearer Token)
# Pustaka ML
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import torch
# Import dari file lokal
from api_client import get_tweets_by_username
from preprocessing import preprocess_text
app = FastAPI()
# --- KONFIGURASI PATH ---
MODEL_DIR = "./model_assets"
WEIGHTS_FILE = "./model_assets/best_indobertweet.pth"
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# --- PEMUATAN MODEL ---
# Logika 3 Tahap: 1. Muat Tokenizer, 2. Muat Struktur Model, 3. Muat Bobot .pth
try:
print("Mencoba memuat model...")
# 1. Muat Tokenizer & Struktur
tokenizer = AutoTokenizer.from_pretrained(MODEL_DIR)
model = AutoModelForSequenceClassification.from_pretrained(MODEL_DIR, num_labels=2)
# 2. Muat Bobot .pth
# Setting map_location ke 'cpu' memastikan bisa dimuat bahkan jika server tidak punya GPU
state_dict = torch.load(WEIGHTS_FILE, map_location=DEVICE)
model.load_state_dict(state_dict)
# 3. Finalisasi
model.to(DEVICE)
model.eval()
print(f"Model berhasil dimuat ke device: {DEVICE}")
except Exception as e:
print(f"FATAL ERROR: Gagal memuat model. Pastikan file di {MODEL_DIR} sudah benar.")
print(e)
# Anda bisa memilih untuk raise error di sini untuk menghentikan server jika model gagal dimuat.
# --- DEFINISI PYDANTIC MODELS ---
# 1. Skema Permintaan (DATA YANG DITERIMA DARI Express.js)
class StressRequest(BaseModel):
x_username: str
tweet_count: int = 100
# 2. Skema Data Hasil (DATA YANG DISIMPAN KE DB & DIKIRIM KE Frontend)
class ResultData(BaseModel):
x_username: str
total_tweets: int
stress_level: int # Skor 0-100 (Probabilitas positif dikalikan 100)
keywords: Dict[str, float] # Contoh Placeholder: Tren Kata
stress_status: int # 0: Aman, 1: Rendah, 2: Sedang, 3: Tinggi
# 3. Skema Respons API Akhir
class APIResponse(BaseModel):
message: str
data: Optional[ResultData]
# --- UTILITY FUNCTIONS ---
def calculate_stress_status(stress_level: float) -> int:
"""Mengkonversi skor probabilitas (0-100) menjadi status diskrit."""
if stress_level >= 75:
return 3 # Tinggi
elif stress_level >= 50:
return 2 # Sedang
elif stress_level >= 25:
return 1 # Rendah
else:
return 0 # Aman
# --- ENDPOINT UTAMA ---
@app.post("/api/predict_stress", response_model=APIResponse)
def predict_stress(request: StressRequest):
username = request.x_username
tweet_count = request.tweet_count
# Cek Bearer Token sebelum memanggil API Twitter
if not os.environ.get("TWITTER_BEARER_TOKEN"):
return APIResponse(
message="Error: TWITTER_BEARER_TOKEN tidak diatur sebagai environment variable.",
data=None
)
# 1. Ambil Tweet
raw_tweets = get_tweets_by_username(username, tweet_count)
if not raw_tweets:
return APIResponse(
message=f"Gagal mengambil tweet dari @{username}. Akun mungkin private atau tidak ditemukan.",
data=None
)
# 2. Pre-processing dan Inferensi
cleaned_texts = [preprocess_text(t) for t in raw_tweets]
# Inisialisasi list untuk menyimpan probabilitas stress (kelas 1)
stress_probabilities = []
with torch.no_grad():
for text in cleaned_texts:
# Tokenisasi
enc = tokenizer(
text,
truncation=True,
padding="max_length",
max_length=128, # Konsisten dengan training
return_tensors="pt"
)
# Pindahkan ke device dan inferensi
input_ids = enc["input_ids"].to(DEVICE)
attention_mask = enc["attention_mask"].to(DEVICE)
outputs = model(input_ids=input_ids, attention_mask=attention_mask)
logits = outputs.logits
# Ambil probabilitas untuk kelas 1 (Stress)
probs = torch.softmax(logits, dim=1).cpu().numpy()[0]
stress_probabilities.append(probs[1])
# 3. Agregasi Hasil
if not stress_probabilities:
# Jika ada tweets tapi semuanya kosong setelah pre-processing
avg_stress_score = 0
else:
# Hitung rata-rata probabilitas stres dari semua tweet (skor 0.0 - 1.0)
avg_stress_score = np.mean(stress_probabilities)
# Konversi ke skala 0-100
stress_level_100 = int(round(avg_stress_score * 100))
status = calculate_stress_status(stress_level_100)
# 4. Susun Respons
result_data = ResultData(
x_username=username,
total_tweets=len(raw_tweets),
stress_level=stress_level_100,
keywords={"placeholder": 0.0}, # Implementasi penambangan keyword akan dilakukan belakangan
stress_status=status
)
return APIResponse(
message=f"Analisis stres untuk @{username} berhasil. Ditemukan {result_data.total_tweets} tweets.",
data=result_data
)