Spaces:
Sleeping
Sleeping
| import os | |
# ✅ Point the Hugging Face / PyTorch caches at writable /tmp locations
# (recommended to avoid cache errors). This runs before the model-related
# imports further down so those libraries pick up these paths.
os.environ.update(
    {
        "TRANSFORMERS_CACHE": "/tmp/transformers",
        "HF_HOME": "/tmp/huggingface",
        "TORCH_HOME": "/tmp/torch",
    }
)
| from fastapi import FastAPI, HTTPException | |
| from fastapi.middleware.cors import CORSMiddleware | |
| from fastapi.staticfiles import StaticFiles | |
| from fastapi.responses import FileResponse, JSONResponse | |
| from pydantic import BaseModel | |
| from datetime import datetime | |
| from typing import Optional, List | |
| from bert_explainer import analyze_text as bert_analyze_text | |
| from firebase_admin import credentials, firestore | |
| import firebase_admin | |
| import pytz | |
| import json | |
| import requests | |
| import torch | |
# ✅ Create the FastAPI application; this metadata appears in the generated API docs.
_APP_METADATA = {
    "title": "詐騙訊息辨識 API",
    "description": "使用 BERT 模型分析輸入文字是否為詐騙內容",
    "version": "1.0.0",
}
app = FastAPI(**_APP_METADATA)
# ✅ Cross-origin (CORS) handling: every origin, method and header is accepted.
# NOTE(review): wildcard origins combined with allow_credentials=True is very
# permissive; if cookies/auth are actually used, restrict allow_origins to the
# real front-end origin(s).
_CORS_SETTINGS = dict(
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
app.add_middleware(CORSMiddleware, **_CORS_SETTINGS)
# ✅ Mount static front-end assets (script.js / style.css) at /static,
# served from the current working directory.
# NOTE(review): directory="." publishes *every* file in the working directory
# (including this source file) under /static; consider a dedicated assets folder.
_static_assets = StaticFiles(directory=".")
app.mount("/static", _static_assets, name="static")
# ✅ Serve the home page (index.html)
@app.get("/")
async def serve_index():
    """Return the front-end entry page.

    Without the route decorator this coroutine was defined but never
    registered, so the home page was unreachable.

    Returns:
        FileResponse: ``index.html`` resolved relative to the current
        working directory.
    """
    return FileResponse("index.html")
# ✅ Firebase initialisation (best-effort: the app still starts if it fails).
# Bind `db` up front so it always exists; it stays None when initialisation
# fails, instead of being an unbound name that raises NameError in handlers.
db = None
try:
    cred_data = os.getenv("FIREBASE_CREDENTIALS")
    if not cred_data:
        raise ValueError("FIREBASE_CREDENTIALS 環境變數未設置")
    # The env var holds the service-account JSON; "type" is supplied as a
    # default and is overridden if the JSON already contains it.
    cred = credentials.Certificate({"type": "service_account", **json.loads(cred_data)})
    firebase_admin.initialize_app(cred)
    db = firestore.client()
except Exception as e:
    print(f"Firebase 初始化錯誤: {e}")
# ✅ Download the model weights from the Hugging Face Hub (stored under /tmp)
def load_model_from_hub(
    model_url="https://huggingface.co/jerrynnms/scam-model/resolve/main/model.pth",
    model_path="/tmp/model.pth",
    timeout=60,
):
    """Download the classifier weights and return the local file path.

    The response body is streamed to ``<model_path>.part`` and renamed into
    place only once the download has completed. An interrupted download
    therefore never leaves a truncated file at ``model_path``, which matters
    because the startup code skips the download whenever that path exists.

    Args:
        model_url: URL of the weights file.
        model_path: Destination path for the weights.
        timeout: Seconds ``requests`` waits to connect and between received
            bytes before giving up (previously there was no timeout at all).

    Returns:
        str: ``model_path``.

    Raises:
        FileNotFoundError: If the server responds with a status other than 200.
        requests.RequestException: On network errors or timeouts.
    """
    tmp_path = f"{model_path}.part"
    with requests.get(model_url, stream=True, timeout=timeout) as response:
        if response.status_code != 200:
            raise FileNotFoundError("❌ 無法從 Hugging Face 載入 model.pth")
        try:
            # Stream in 1 MiB chunks instead of buffering the whole file in memory.
            with open(tmp_path, "wb") as f:
                for chunk in response.iter_content(chunk_size=1024 * 1024):
                    f.write(chunk)
        except BaseException:
            # Don't leave a partial download behind.
            if os.path.exists(tmp_path):
                os.remove(tmp_path)
            raise
    # Same directory, so the rename is atomic on POSIX filesystems.
    os.replace(tmp_path, model_path)
    return model_path
# ✅ Load the classifier weights onto the CPU, fetching them from the Hub on first start.
model_path = "/tmp/model.pth"
model_path = model_path if os.path.exists(model_path) else load_model_from_hub()

from AI_Model_architecture import BertLSTM_CNN_Classifier

# NOTE(review): `model` is not referenced by the request handlers visible here
# (they call bert_analyze_text); confirm it is still needed.
model = BertLSTM_CNN_Classifier()
# NOTE(review): torch.load unpickles the file, which was downloaded over the
# network. Since it is consumed as a state_dict, consider passing
# weights_only=True if the installed torch version supports it.
model.load_state_dict(
    torch.load(model_path, map_location="cpu")
)
# Put the module into evaluation (inference) mode.
model.eval()
# ✅ Request/response data models
# Request body accepted by analyze_text_api: the text to classify plus an
# optional caller-supplied user ID.
class TextAnalysisRequest(BaseModel):
    text: str  # raw message text passed to bert_analyze_text and stored in Firestore
    user_id: Optional[str] = None  # stored with the record as-is; may be omitted
# Response body returned by analyze_text_api. status / confidence /
# suspicious_keywords are copied from the bert_analyze_text result dict.
class TextAnalysisResponse(BaseModel):
    status: str  # result["status"]; the set of labels is defined by bert_explainer
    confidence: float  # result["confidence"]; scale not visible here (presumably 0-1 — confirm)
    suspicious_keywords: List[str]  # result["suspicious_keywords"]
    analysis_timestamp: datetime  # timezone-aware time of the analysis (Asia/Taipei)
    text_id: str  # Firestore document ID under which the record was stored
# ✅ /predict API
@app.post("/predict", response_model=TextAnalysisResponse)
async def analyze_text_api(request: TextAnalysisRequest):
    """Classify ``request.text``, persist the result to Firestore and return it.

    The record goes into a collection named after the current Asia/Taipei date
    (``YYYYMMDD``). Its document ID is returned to the client as ``text_id``.

    Args:
        request: Text to analyse plus an optional user ID.

    Returns:
        TextAnalysisResponse: Label, confidence, suspicious keywords,
        analysis time and the stored document's ID.

    Raises:
        HTTPException: 500, carrying the error message, if analysis or the
            Firestore write fails.
    """
    import uuid  # only needed to make document IDs unique

    try:
        tz = pytz.timezone("Asia/Taipei")
        now = datetime.now(tz)
        # A second-resolution timestamp alone collides for requests arriving
        # within the same second, and .set() would silently overwrite the
        # earlier record. A random suffix keeps IDs unique while preserving
        # the chronologically sortable prefix.
        doc_id = f"{now.strftime('%Y%m%dT%H%M%S')}-{uuid.uuid4().hex[:8]}"
        date_str = now.strftime("%Y-%m-%d %H:%M:%S")
        collection = now.strftime("%Y%m%d")
        # NOTE(review): analysis and the Firestore write are blocking calls
        # inside an async handler, so they stall the event loop while they run.
        result = bert_analyze_text(request.text)
        record = {
            "text_id": doc_id,
            "text": request.text,
            "user_id": request.user_id,
            "analysis_result": result,
            "timestamp": date_str,
            "type": "text_analysis"
        }
        db.collection(collection).document(doc_id).set(record)
        return TextAnalysisResponse(
            status=result["status"],
            confidence=result["confidence"],
            suspicious_keywords=result["suspicious_keywords"],
            analysis_timestamp=now,
            text_id=doc_id
        )
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
# ✅ /feedback API
@app.post("/feedback")
async def save_user_feedback(feedback: dict):
    """Store a user-submitted feedback payload in the ``user_feedback`` collection.

    The stored record is the payload plus a server-side ``timestamp``
    (Asia/Taipei, ``YYYY-mm-dd HH:MM:SS``) and ``used_in_training=False``.
    Those two keys override any client-supplied values. A new dict is built
    rather than mutating the incoming one.

    Args:
        feedback: Arbitrary JSON object sent by the client.
            NOTE(review): stored unvalidated; consider a pydantic model.

    Returns:
        dict: A confirmation message.

    Raises:
        HTTPException: 500, carrying the error message, if the Firestore write fails.
    """
    try:
        tz = pytz.timezone("Asia/Taipei")
        record = {
            **feedback,
            "timestamp": datetime.now(tz).strftime("%Y-%m-%d %H:%M:%S"),
            "used_in_training": False,
        }
        db.collection("user_feedback").add(record)
        return {"message": "✅ 已記錄使用者回饋"}
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))