from fastapi import FastAPI from pydantic import BaseModel from typing import Dict, Optional, List import numpy as np # Untuk operasi np.argmax import os # Untuk os.environ (Bearer Token) # Pustaka ML from transformers import AutoTokenizer, AutoModelForSequenceClassification import torch # Import dari file lokal from api_client import get_tweets_by_username from preprocessing import preprocess_text app = FastAPI() # --- KONFIGURASI PATH --- MODEL_DIR = "./model_assets" WEIGHTS_FILE = "./model_assets/best_indobertweet.pth" DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu") # --- PEMUATAN MODEL --- # Logika 3 Tahap: 1. Muat Tokenizer, 2. Muat Struktur Model, 3. Muat Bobot .pth try: print("Mencoba memuat model...") # 1. Muat Tokenizer & Struktur tokenizer = AutoTokenizer.from_pretrained(MODEL_DIR) model = AutoModelForSequenceClassification.from_pretrained(MODEL_DIR, num_labels=2) # 2. Muat Bobot .pth # Setting map_location ke 'cpu' memastikan bisa dimuat bahkan jika server tidak punya GPU state_dict = torch.load(WEIGHTS_FILE, map_location=DEVICE) model.load_state_dict(state_dict) # 3. Finalisasi model.to(DEVICE) model.eval() print(f"Model berhasil dimuat ke device: {DEVICE}") except Exception as e: print(f"FATAL ERROR: Gagal memuat model. Pastikan file di {MODEL_DIR} sudah benar.") print(e) # Anda bisa memilih untuk raise error di sini untuk menghentikan server jika model gagal dimuat. # --- DEFINISI PYDANTIC MODELS --- # 1. Skema Permintaan (DATA YANG DITERIMA DARI Express.js) class StressRequest(BaseModel): x_username: str tweet_count: int = 100 # 2. Skema Data Hasil (DATA YANG DISIMPAN KE DB & DIKIRIM KE Frontend) class ResultData(BaseModel): x_username: str total_tweets: int stress_level: int # Skor 0-100 (Probabilitas positif dikalikan 100) keywords: Dict[str, float] # Contoh Placeholder: Tren Kata stress_status: int # 0: Aman, 1: Rendah, 2: Sedang, 3: Tinggi # 3. Skema Respons API Akhir class APIResponse(BaseModel): message: str data: Optional[ResultData] # --- UTILITY FUNCTIONS --- def calculate_stress_status(stress_level: float) -> int: """Mengkonversi skor probabilitas (0-100) menjadi status diskrit.""" if stress_level >= 75: return 3 # Tinggi elif stress_level >= 50: return 2 # Sedang elif stress_level >= 25: return 1 # Rendah else: return 0 # Aman # --- ENDPOINT UTAMA --- @app.post("/api/predict_stress", response_model=APIResponse) def predict_stress(request: StressRequest): username = request.x_username tweet_count = request.tweet_count # Cek Bearer Token sebelum memanggil API Twitter if not os.environ.get("TWITTER_BEARER_TOKEN"): return APIResponse( message="Error: TWITTER_BEARER_TOKEN tidak diatur sebagai environment variable.", data=None ) # 1. Ambil Tweet raw_tweets = get_tweets_by_username(username, tweet_count) if not raw_tweets: return APIResponse( message=f"Gagal mengambil tweet dari @{username}. Akun mungkin private atau tidak ditemukan.", data=None ) # 2. Pre-processing dan Inferensi cleaned_texts = [preprocess_text(t) for t in raw_tweets] # Inisialisasi list untuk menyimpan probabilitas stress (kelas 1) stress_probabilities = [] with torch.no_grad(): for text in cleaned_texts: # Tokenisasi enc = tokenizer( text, truncation=True, padding="max_length", max_length=128, # Konsisten dengan training return_tensors="pt" ) # Pindahkan ke device dan inferensi input_ids = enc["input_ids"].to(DEVICE) attention_mask = enc["attention_mask"].to(DEVICE) outputs = model(input_ids=input_ids, attention_mask=attention_mask) logits = outputs.logits # Ambil probabilitas untuk kelas 1 (Stress) probs = torch.softmax(logits, dim=1).cpu().numpy()[0] stress_probabilities.append(probs[1]) # 3. Agregasi Hasil if not stress_probabilities: # Jika ada tweets tapi semuanya kosong setelah pre-processing avg_stress_score = 0 else: # Hitung rata-rata probabilitas stres dari semua tweet (skor 0.0 - 1.0) avg_stress_score = np.mean(stress_probabilities) # Konversi ke skala 0-100 stress_level_100 = int(round(avg_stress_score * 100)) status = calculate_stress_status(stress_level_100) # 4. Susun Respons result_data = ResultData( x_username=username, total_tweets=len(raw_tweets), stress_level=stress_level_100, keywords={"placeholder": 0.0}, # Implementasi penambangan keyword akan dilakukan belakangan stress_status=status ) return APIResponse( message=f"Analisis stres untuk @{username} berhasil. Ditemukan {result_data.total_tweets} tweets.", data=result_data )