"""
app.py
------
FastAPI service exposing the Reddit upvotes prediction pipeline.
Endpoints
---------
POST /predict_single — single prediction via hours_ago and post_utc
POST /predict_batch — batch prediction via lists of hours_ago and post_utc
GET /health — liveness check
"""
import sys
import os
from pathlib import Path
# Extend the import search path so the refactored code can be imported:
# PROJECT_ROOT is the directory five levels above this file, and both it and
# its "models" sub-directory are appended (in that order) to sys.path.
PROJECT_ROOT = Path(__file__).resolve().parent.parent.parent.parent.parent
sys.path.extend([str(PROJECT_ROOT), str(PROJECT_ROOT / "models")])
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel, Field, field_validator
import uvicorn
from typing import List, Dict, Any
from models.reddit_predict import predict_single, predict_batch
# The ASGI application object; every route below is registered on it.
# title/description/version are the service metadata handed to FastAPI.
app = FastAPI(
    version="1.0.0",
    title="Reddit Upvotes Predictor",
    description="Predicts the number of upvotes a Reddit post will receive.",
)
# The model file is read from the "artifacts" folder that sits next to this
# file (the local 'test' directory), not from PROJECT_ROOT.
LOCAL_ARTIFACTS_MODEL_PATH = Path(__file__).resolve().parent.joinpath(
    "artifacts", "model_reddit_final.pth"
)
# ---------------------------------------------------------------------------
# Request / Response schemas
# ---------------------------------------------------------------------------
class PostContent(BaseModel):
    """Shared post content fields, inherited by all request models."""

    # All three fields are required and must be non-empty strings.
    # `subreddit` carries min_length=1 so that an empty name is rejected at
    # validation time, the same way an empty `title` or `text` is.
    subreddit: str = Field(
        ...,
        min_length=1,
        description="Subreddit name (e.g. 'MachineLearning')",
    )
    title: str = Field(..., min_length=1, description="Post title")
    text: str = Field(..., min_length=1, description="Post body text")
class PredictSingleRequest(PostContent):
    """Single prediction."""

    # Both values are required (no default) and are forwarded unchanged to
    # predict_single by the /predict_single endpoint.
    # NOTE(review): units/epoch of post_utc are not visible here — confirm
    # against models.reddit_predict.
    hours_ago: float = Field(...)
    post_utc: float = Field(...)
class PredictBatchRequest(PostContent):
    """Batch prediction."""

    # Two required lists, each with at least one element.
    hours_ago_list: List[float] = Field(..., min_length=1)
    post_utc_list: List[float] = Field(..., min_length=1)

    @field_validator("post_utc_list")
    @classmethod
    def validate_lengths(cls, v: List[float], info: Any) -> List[float]:
        # Runs on post_utc_list. `info.data` is consulted for hours_ago_list;
        # when that key is absent the length comparison is skipped and the
        # value is passed through untouched.
        already_validated = info.data
        if "hours_ago_list" not in already_validated:
            return v
        if len(v) == len(already_validated["hours_ago_list"]):
            return v
        raise ValueError("hours_ago_list and post_utc_list must have the same length.")
class PredictSingleResponse(BaseModel):
    # Response body of POST /predict_single: the value returned by
    # predict_single, exposed under the "prediction" key.
    prediction: float = Field(...)
class BatchPredictionItem(BaseModel):
    # One element of PredictBatchResponse.predictions. All three values are
    # required.
    hours_ago: float = Field(...)
    post_utc: float = Field(...)
    prediction: float = Field(...)
class PredictBatchResponse(BaseModel):
    # Response body of POST /predict_batch. The endpoint fills `count` with
    # len() of the result list it places in `predictions`.
    predictions: List[BatchPredictionItem] = Field(...)
    count: int = Field(...)
# ---------------------------------------------------------------------------
# Endpoints
# ---------------------------------------------------------------------------
@app.get("/health", tags=["Meta"])
def health():
    """Liveness check."""
    # Constant payload: this handler touches neither the model nor the disk.
    status_payload = {"status": "ok"}
    return status_payload
@app.post(
    "/predict_single",
    response_model=PredictSingleResponse,
    summary="Single prediction with mutual specific timestamp",
    tags=["Predict"],
)
def api_predict_single(request: PredictSingleRequest):
    """
    Predict upvotes for a single post context given both hours_ago and post_utc.
    """
    # Outcomes:
    #   success            -> PredictSingleResponse wrapping predict_single's result
    #   ValueError         -> HTTP 422 (the request content was rejected)
    #   FileNotFoundError  -> HTTP 500 (a file the server needs is missing)
    try:
        result = predict_single(
            subreddit=request.subreddit,
            title=request.title,
            text=request.text,
            hours_ago=request.hours_ago,
            post_utc=request.post_utc,
            model_path=LOCAL_ARTIFACTS_MODEL_PATH,
        )
    except ValueError as e:
        raise HTTPException(status_code=422, detail=str(e)) from e
    except FileNotFoundError as e:
        # A missing file cannot be fixed by the client changing its request,
        # so it is reported as a server error rather than as 422.
        # NOTE(review): str(e) may expose a server-side path to the client;
        # consider a generic message plus server-side logging.
        raise HTTPException(status_code=500, detail=str(e)) from e
    return PredictSingleResponse(prediction=result)
@app.post(
    "/predict_batch",
    response_model=PredictBatchResponse,
    summary="Batch prediction with matching length lists",
    tags=["Predict"],
)
def api_predict_batch(request: PredictBatchRequest):
    """
    Predict upvotes for an array of potential posting times.
    """
    # Outcomes:
    #   success            -> PredictBatchResponse with predict_batch's results
    #                         and their count
    #   ValueError         -> HTTP 422 (the request content was rejected)
    #   FileNotFoundError  -> HTTP 500 (a file the server needs is missing)
    try:
        results = predict_batch(
            subreddit=request.subreddit,
            title=request.title,
            text=request.text,
            hours_ago_list=request.hours_ago_list,
            post_utc_list=request.post_utc_list,
            model_path=LOCAL_ARTIFACTS_MODEL_PATH,
        )
    except ValueError as e:
        raise HTTPException(status_code=422, detail=str(e)) from e
    except FileNotFoundError as e:
        # A missing file cannot be fixed by the client changing its request,
        # so it is reported as a server error rather than as 422.
        # NOTE(review): str(e) may expose a server-side path to the client;
        # consider a generic message plus server-side logging.
        raise HTTPException(status_code=500, detail=str(e)) from e
    return PredictBatchResponse(predictions=results, count=len(results))
# ---------------------------------------------------------------------------
# Entry point
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    # Direct execution (`python app.py`): serve `app` on every network
    # interface, port 7860.
    # NOTE(review): 7860 is presumably the port the hosting platform expects —
    # confirm before changing it.
    uvicorn.run(app, port=7860, host="0.0.0.0")
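# --- Developer notes (reference only, never executed) ---
# git add <filename>
# git commit -m <message>
# git push
# git remote -v   (lists the configured remotes)
# git config --global credential.helper store
# uvicorn app:app --reload   (runs this app locally with auto-reload)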
# git commit conventions?
# https://cyrilfrl-test.hf.space/greet
# git commit conventions
# curl -X POST https://cyrilfrl-test.hf.space/greet -H "Content-Type: application/json" -d '{"name": "Cyril"}'
# (dev) C:\Users\Cyril\Desktop\B3Q2\startup\app\HuggingFace\RedditPredictUpvotes\test>curl -X POST https://cyrilfrl-test.hf.space/greet -H "Content-Type: application/json" -d "{\"name\": \"Cyril\"}"
# {"message":"Hello Cyril!!"}
# (dev) C:\Users\Cyril\Desktop\B3Q2\startup\app\HuggingFace\RedditPredictUpvotes\test>
# curl -X POST "http://127.0.0.1:8000/greet" -H "Content-Type: application/json" -d "{\"name\": \"Cyril\"}"