File size: 4,252 Bytes
546d7be
 
768506f
87caf00
 
 
546d7be
 
 
768506f
 
546d7be
 
 
 
 
 
 
 
 
 
 
768506f
546d7be
 
 
 
 
 
768506f
546d7be
 
 
 
 
 
 
 
768506f
 
546d7be
768506f
 
 
 
546d7be
768506f
546d7be
 
 
 
 
 
 
 
 
 
768506f
 
1b7b655
768506f
546d7be
 
768506f
546d7be
768506f
 
546d7be
768506f
 
 
546d7be
 
 
768506f
546d7be
 
768506f
 
 
 
 
 
 
 
 
 
 
546d7be
768506f
546d7be
 
 
 
768506f
 
 
 
546d7be
 
 
 
768506f
546d7be
 
768506f
 
546d7be
 
 
768506f
546d7be
 
 
 
 
768506f
 
546d7be
 
 
768506f
 
546d7be
 
 
 
768506f
546d7be
768506f
546d7be
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
import os

# ✅ Hugging Face-recommended cache locations (avoids cache errors).
# These assignments run before the third-party imports below, so any library
# that reads these variables at import time will see the /tmp locations.
os.environ.update(
    {
        "TRANSFORMERS_CACHE": "/tmp/transformers",
        "HF_HOME": "/tmp/huggingface",
        "TORCH_HOME": "/tmp/torch",
    }
)

from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from fastapi.staticfiles import StaticFiles
from fastapi.responses import FileResponse, JSONResponse
from pydantic import BaseModel
from datetime import datetime
from typing import Optional, List
from bert_explainer import analyze_text as bert_analyze_text
from firebase_admin import credentials, firestore
import firebase_admin
import pytz
import json
import requests
import torch

# ✅ Create the FastAPI application (title/description appear in the OpenAPI docs).
app = FastAPI(
    title="詐騙訊息辨識 API",
    description="使用 BERT 模型分析輸入文字是否為詐騙內容",
    version="1.0.0"
)

# ✅ Cross-origin (CORS) handling: every origin, method and header is allowed.
# NOTE(review): allow_origins=["*"] together with allow_credentials=True is
# contradictory per the CORS spec; Starlette's CORSMiddleware presumably works
# around it by echoing the caller's Origin on credentialed requests, which lets
# any site make credentialed calls — verify for the installed version and
# restrict origins if the API ever relies on cookies/auth.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# ✅ Serve static files under /static (used for script.js / style.css).
# NOTE(review): directory="." is relative to the process working directory and
# exposes every file in it (likely including this server source) at /static/...
# Consider moving the front-end assets into a dedicated folder.
app.mount("/static", StaticFiles(directory="."), name="static")

# ✅ Front-end entry point: GET / returns the index page.
@app.get("/", response_class=FileResponse)
async def serve_index():
    """Return ``index.html``, resolved relative to the working directory."""
    return FileResponse(path="index.html")

# ✅ Firebase initialisation.
# Credentials come from the FIREBASE_CREDENTIALS environment variable (a JSON
# service-account document). Any failure is printed and start-up continues;
# ``db`` is then left as None instead of being undefined, so a missing Firestore
# client surfaces as an explicit None rather than a NameError in the handlers.
db = None
try:
    cred_data = os.getenv("FIREBASE_CREDENTIALS")
    if not cred_data:
        raise ValueError("FIREBASE_CREDENTIALS 環境變數未設置")
    try:
        # Reuse the default app if one already exists in this process
        # (e.g. the module executed twice); initialize_app() would raise
        # ValueError for a duplicate default app.
        firebase_admin.get_app()
    except ValueError:
        # "type" defaults to service_account; a "type" key in the JSON wins.
        cred = credentials.Certificate({"type": "service_account", **json.loads(cred_data)})
        firebase_admin.initialize_app(cred)
    db = firestore.client()
except Exception as e:
    print(f"Firebase 初始化錯誤: {e}")

# ✅ Download the model checkpoint from the Hugging Face Hub into /tmp.
def load_model_from_hub(
    model_url: str = "https://huggingface.co/jerrynnms/scam-model/resolve/main/model.pth",
    model_path: str = "/tmp/model.pth",
    timeout: float = 60,
) -> str:
    """Download the classifier checkpoint and return its local path.

    The response body is streamed into a temporary file in the destination
    directory and then atomically renamed onto ``model_path``. An interrupted
    download or failed write therefore never leaves a truncated checkpoint at
    ``model_path`` (which start-up code would otherwise reuse, since it only
    checks whether the file exists).

    Args:
        model_url: URL of the checkpoint file.
        model_path: Local destination path.
        timeout: ``requests`` timeout in seconds (connect / between reads);
            without one, a stalled connection would hang start-up forever.

    Returns:
        ``model_path`` once the file is fully written.

    Raises:
        FileNotFoundError: if the server responds with a non-200 status.
        requests.RequestException: on network errors or timeouts.
    """
    import tempfile

    with requests.get(model_url, stream=True, timeout=timeout) as response:
        if response.status_code != 200:
            raise FileNotFoundError("❌ 無法從 Hugging Face 載入 model.pth")
        # Same directory as the target so os.replace() is an atomic rename.
        fd, tmp_path = tempfile.mkstemp(
            dir=os.path.dirname(model_path) or ".", suffix=".part"
        )
        try:
            with os.fdopen(fd, "wb") as f:
                # Stream in 1 MiB chunks instead of buffering the whole model.
                for chunk in response.iter_content(chunk_size=1024 * 1024):
                    f.write(chunk)
            os.replace(tmp_path, model_path)
        except BaseException:
            # Never leave partial downloads lying around.
            try:
                os.unlink(tmp_path)
            except OSError:
                pass
            raise
    return model_path

# Reuse a previously downloaded checkpoint when present, otherwise fetch it.
# NOTE(review): only existence is checked, so a corrupt/truncated file at this
# path would be reused as-is; this path must also match the one used by
# load_model_from_hub().
model_path = "/tmp/model.pth"
if not os.path.exists(model_path):
    model_path = load_model_from_hub()

# Imported here, after the checkpoint is available (presumably deliberate — confirm).
from AI_Model_architecture import BertLSTM_CNN_Classifier
model = BertLSTM_CNN_Classifier()
# NOTE(review): torch.load unpickles a file downloaded over the network; on
# torch versions that support it, weights_only=True avoids executing arbitrary
# pickled code when loading a plain state dict.
model.load_state_dict(torch.load(model_path, map_location="cpu"))
# Put the module in evaluation (inference) mode.
model.eval()
# NOTE(review): `model` is not referenced anywhere else in this file — /predict
# calls bert_explainer.analyze_text instead. Confirm whether this load is needed.

# ✅ Request / response schemas
class TextAnalysisRequest(BaseModel):
    """Request body for ``POST /predict``."""

    text: str  # message text to analyse
    user_id: Optional[str] = None  # optional caller id, stored alongside the record

class TextAnalysisResponse(BaseModel):
    """Response body returned by ``POST /predict``."""

    status: str  # "status" entry of the analysis result
    confidence: float  # "confidence" entry of the analysis result
    suspicious_keywords: List[str]  # "suspicious_keywords" entry of the analysis result
    analysis_timestamp: datetime  # time of the request (timezone-aware, Asia/Taipei)
    text_id: str  # Firestore document id under which the record was stored

# ✅ /predict API: classify a text and persist the result to Firestore.
@app.post("/predict", response_model=TextAnalysisResponse)
async def analyze_text_api(request: TextAnalysisRequest):
    """Analyse ``request.text`` and store the result in Firestore.

    The record is written to a collection named after the current
    Asia/Taipei date (``YYYYMMDD``). Its document id is a
    ``YYYYMMDDTHHMMSS`` timestamp followed by a short random suffix.

    Returns:
        TextAnalysisResponse built from the ``status``, ``confidence`` and
        ``suspicious_keywords`` entries of the analysis result, plus the
        request time and the document id.

    Raises:
        HTTPException: 500 carrying the error message for any failure in
            analysis or storage.
    """
    import uuid

    from fastapi.concurrency import run_in_threadpool

    try:
        tz = pytz.timezone("Asia/Taipei")
        now = datetime.now(tz)
        # A second-resolution timestamp alone collides for two requests in the
        # same second, and .set() would silently overwrite the earlier record.
        # The random suffix keeps ids unique; the sortable prefix is unchanged.
        doc_id = f"{now.strftime('%Y%m%dT%H%M%S')}-{uuid.uuid4().hex[:8]}"
        date_str = now.strftime("%Y-%m-%d %H:%M:%S")
        collection = now.strftime("%Y%m%d")

        # NOTE(review): this call runs synchronously on the event loop and
        # blocks other requests while it executes; offloading it to a thread
        # is only safe if bert_analyze_text is thread-safe — confirm first.
        result = bert_analyze_text(request.text)

        record = {
            "text_id": doc_id,
            "text": request.text,
            "user_id": request.user_id,
            "analysis_result": result,
            "timestamp": date_str,
            "type": "text_analysis"
        }

        # The Firestore client is synchronous; run the network write in the
        # threadpool so it does not stall the event loop.
        doc_ref = db.collection(collection).document(doc_id)
        await run_in_threadpool(doc_ref.set, record)

        return TextAnalysisResponse(
            status=result["status"],
            confidence=result["confidence"],
            suspicious_keywords=result["suspicious_keywords"],
            analysis_timestamp=now,
            text_id=doc_id
        )
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e

# ✅ /feedback API: store user feedback in Firestore.
@app.post("/feedback")
async def save_user_feedback(feedback: dict):
    """Store a feedback payload in the ``user_feedback`` collection.

    The payload is stored as sent, except that ``timestamp`` (Asia/Taipei,
    ``YYYY-MM-DD HH:MM:SS``) and ``used_in_training`` (False) are always set
    by the server, overwriting any client-supplied values.

    Returns:
        A confirmation message dict.

    Raises:
        HTTPException: 500 carrying the error message if storing fails.
    """
    from fastapi.concurrency import run_in_threadpool

    try:
        tz = pytz.timezone("Asia/Taipei")
        timestamp_str = datetime.now(tz).strftime("%Y-%m-%d %H:%M:%S")
        feedback["timestamp"] = timestamp_str
        feedback["used_in_training"] = False
        # Synchronous Firestore network call: run it off the event loop.
        await run_in_threadpool(db.collection("user_feedback").add, feedback)
        return {"message": "✅ 已記錄使用者回饋"}
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e