# scam-detector / app.py
# (Hugging Face Space header: jerrynnms — "Update app.py", commit 768506f, verified)
import os
# Hugging Face recommended cache locations (prevents cache errors on Spaces).
for _env_var, _cache_path in (
    ("TRANSFORMERS_CACHE", "/tmp/transformers"),
    ("HF_HOME", "/tmp/huggingface"),
    ("TORCH_HOME", "/tmp/torch"),
):
    os.environ[_env_var] = _cache_path
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from fastapi.staticfiles import StaticFiles
from fastapi.responses import FileResponse, JSONResponse
from pydantic import BaseModel
from datetime import datetime
from typing import Optional, List
from bert_explainer import analyze_text as bert_analyze_text
from firebase_admin import credentials, firestore
import firebase_admin
import pytz
import json
import requests
import torch
# FastAPI application setup.
_API_METADATA = dict(
    title="詐騙訊息辨識 API",
    description="使用 BERT 模型分析輸入文字是否為詐騙內容",
    version="1.0.0",
)
app = FastAPI(**_API_METADATA)

# Cross-origin policy: fully open so browser clients on any host can call the API.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Serve static assets (script.js / style.css) straight from the working directory.
app.mount("/static", StaticFiles(directory="."), name="static")
# Root route: hand back the single-page front end.
@app.get("/", response_class=FileResponse)
async def serve_index():
    """Return index.html for GET /."""
    return FileResponse(path="index.html")
# Firebase initialization. Credentials come from the FIREBASE_CREDENTIALS env
# var, a JSON service-account key without its "type" field (added here).
# Bind db to an explicit None sentinel first: without it, a failed init left
# `db` unbound and every endpoint later crashed with a confusing NameError.
db = None
try:
    cred_data = os.getenv("FIREBASE_CREDENTIALS")
    if not cred_data:
        raise ValueError("FIREBASE_CREDENTIALS 環境變數未設置")
    cred = credentials.Certificate({"type": "service_account", **json.loads(cred_data)})
    firebase_admin.initialize_app(cred)
    db = firestore.client()
except Exception as e:
    # Best-effort: the app still starts (e.g. for / and /static) even when
    # Firestore is unavailable; endpoints that need `db` will then 500.
    print(f"Firebase 初始化錯誤: {e}")
# Download model weights from the Hugging Face Hub into /tmp.
def load_model_from_hub():
    """Fetch model.pth from the Hub and save it to /tmp/model.pth.

    Returns:
        str: the local path of the downloaded checkpoint.

    Raises:
        FileNotFoundError: if the download fails (bad status or network error).
    """
    model_url = "https://huggingface.co/jerrynnms/scam-model/resolve/main/model.pth"
    model_path = "/tmp/model.pth"
    try:
        # timeout: a hung connection must not block startup forever.
        # stream: write in chunks instead of holding the checkpoint in memory.
        response = requests.get(model_url, timeout=60, stream=True)
    except requests.RequestException as exc:
        # Normalize network failures to the documented exception type.
        raise FileNotFoundError("❌ 無法從 Hugging Face 載入 model.pth") from exc
    if response.status_code == 200:
        with open(model_path, "wb") as f:
            for chunk in response.iter_content(chunk_size=1 << 20):
                f.write(chunk)
        return model_path
    raise FileNotFoundError("❌ 無法從 Hugging Face 載入 model.pth")
# Make sure the checkpoint exists locally; download it on first run.
model_path = "/tmp/model.pth"
if not os.path.exists(model_path):
    model_path = load_model_from_hub()

from AI_Model_architecture import BertLSTM_CNN_Classifier

# Build the classifier and load the trained weights on CPU.
# NOTE(review): consider torch.load(..., weights_only=True) if the checkpoint
# contains only tensors — confirm before changing.
model = BertLSTM_CNN_Classifier()
state_dict = torch.load(model_path, map_location="cpu")
model.load_state_dict(state_dict)
model.eval()  # inference mode
# Request/response schemas.
class TextAnalysisRequest(BaseModel):
    """Request payload for POST /predict."""

    text: str                      # the message to analyze for scam content
    user_id: Optional[str] = None  # optional caller identifier, stored with the record
class TextAnalysisResponse(BaseModel):
    """Response payload returned by POST /predict."""

    status: str                     # verdict label from the analyzer result
    confidence: float               # analyzer confidence score
    suspicious_keywords: List[str]  # keywords the analyzer flagged
    analysis_timestamp: datetime    # when the analysis ran (Asia/Taipei time)
    text_id: str                    # Firestore document id of the stored record
# /predict: run BERT analysis on the submitted text and persist the result.
@app.post("/predict", response_model=TextAnalysisResponse)
async def analyze_text_api(request: TextAnalysisRequest):
    """Analyze request.text for scam content and store the record in Firestore.

    Records go into a per-day collection (YYYYMMDD) keyed by a timestamp id.

    Raises:
        HTTPException(500): on any analysis or persistence failure.
    """
    try:
        tz = pytz.timezone("Asia/Taipei")
        now = datetime.now(tz)
        # Include microseconds: the old one-second id let two requests in the
        # same second collide and silently overwrite each other's document.
        doc_id = now.strftime("%Y%m%dT%H%M%S%f")
        date_str = now.strftime("%Y-%m-%d %H:%M:%S")
        collection = now.strftime("%Y%m%d")  # one Firestore collection per day

        result = bert_analyze_text(request.text)

        record = {
            "text_id": doc_id,
            "text": request.text,
            "user_id": request.user_id,
            "analysis_result": result,
            "timestamp": date_str,
            "type": "text_analysis",
        }
        db.collection(collection).document(doc_id).set(record)

        return TextAnalysisResponse(
            status=result["status"],
            confidence=result["confidence"],
            suspicious_keywords=result["suspicious_keywords"],
            analysis_timestamp=now,
            text_id=doc_id,
        )
    except Exception as e:
        # Chain the cause so server logs show the real failure, not just the 500.
        raise HTTPException(status_code=500, detail=str(e)) from e
# /feedback: store user feedback for later model retraining.
@app.post("/feedback")
async def save_user_feedback(feedback: dict):
    """Persist a free-form feedback dict to the `user_feedback` collection.

    Stamps the entry with Asia/Taipei local time and marks it as not yet
    used for training before writing it.

    Raises:
        HTTPException(500): if the Firestore write fails.
    """
    try:
        tz = pytz.timezone("Asia/Taipei")
        timestamp_str = datetime.now(tz).strftime("%Y-%m-%d %H:%M:%S")
        feedback["timestamp"] = timestamp_str
        feedback["used_in_training"] = False  # retraining pipeline flips this later
        db.collection("user_feedback").add(feedback)
        return {"message": "✅ 已記錄使用者回饋"}
    except Exception as e:
        # Chain the cause so the original Firestore error survives in logs.
        raise HTTPException(status_code=500, detail=str(e)) from e