test / app.py
cyrilfrl's picture
new model
f14d002
"""
app.py
------
FastAPI service exposing the Reddit upvotes prediction pipeline.
Endpoints
---------
POST /predict_single — single prediction via hours_ago and post_utc
POST /predict_batch — batch prediction via lists of hours_ago and post_utc
GET /health — liveness check
"""
import sys
import os
from pathlib import Path
# Make the refactored project code importable by extending sys.path with the
# directory five levels above this file and with its "models" subdirectory.
# NOTE(review): the five-level hop is tied to the repository layout — confirm
# it still resolves to the intended root wherever this file is deployed.
PROJECT_ROOT = Path(__file__).resolve().parent.parent.parent.parent.parent
sys.path.extend(str(entry) for entry in (PROJECT_ROOT, PROJECT_ROOT / "models"))
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel, Field, field_validator
import uvicorn
from typing import List, Dict, Any
from models.reddit_predict import predict_single, predict_batch
# ASGI application object that the route decorators below register against.
# title/description/version are the service metadata handed to FastAPI.
app = FastAPI(
    title="Reddit Upvotes Predictor",
    description="Predicts the number of upvotes a Reddit post will receive.",
    version="1.0.0",
)
# The model checkpoint is resolved relative to this file rather than the
# working directory: <directory of this file>/artifacts/model_reddit_final.pth
LOCAL_ARTIFACTS_MODEL_PATH = Path(__file__).resolve().parent.joinpath(
    "artifacts", "model_reddit_final.pth"
)
# ---------------------------------------------------------------------------
# Request / Response schemas
# ---------------------------------------------------------------------------
class PostContent(BaseModel):
    """Shared post content fields, inherited by all request models."""

    # All three fields are required (`...` means "no default").
    # NOTE(review): title and text must be non-empty (min_length=1) but
    # subreddit has no such constraint — confirm an empty subreddit is intended.
    subreddit: str = Field(..., description="Subreddit name (e.g. 'MachineLearning')")
    title: str = Field(..., min_length=1, description="Post title")
    text: str = Field(..., min_length=1, description="Post body text")
class PredictSingleRequest(PostContent):
    """Single prediction."""

    # Both values are required and are forwarded as-is to predict_single.
    # NOTE(review): units/ranges are not validated here — presumably hours and
    # a Unix timestamp; confirm against models.reddit_predict.
    hours_ago: float
    post_utc: float
class PredictBatchRequest(PostContent):
    """Batch prediction."""

    # Two required, non-empty lists that must be of equal length.
    # NOTE(review): presumably entry i of each list describes the same
    # candidate posting time — confirm against predict_batch.
    hours_ago_list: List[float] = Field(..., min_length=1)
    post_utc_list: List[float] = Field(..., min_length=1)

    @field_validator("post_utc_list")
    @classmethod
    def validate_lengths(cls, v: List[float], info: Any) -> List[float]:
        # Reject a request whose two lists differ in length.
        # When hours_ago_list is not present in info.data there is nothing to
        # compare against, so the value is accepted unchanged.
        reference = info.data.get("hours_ago_list")
        if reference is None:
            return v
        if len(reference) != len(v):
            raise ValueError("hours_ago_list and post_utc_list must have the same length.")
        return v
# Response body of POST /predict_single.
class PredictSingleResponse(BaseModel):
    # Value returned by predict_single for the requested post.
    prediction: float
# One element of PredictBatchResponse.predictions.
# NOTE(review): presumably hours_ago/post_utc echo one input pair and
# prediction is the model output for it — confirm against predict_batch.
class BatchPredictionItem(BaseModel):
    hours_ago: float
    post_utc: float
    prediction: float
# Response body of POST /predict_batch.
class PredictBatchResponse(BaseModel):
    # Items built from whatever predict_batch returned.
    predictions: List[BatchPredictionItem]
    # Number of returned results (the endpoint sets this to len(results)).
    count: int
# ---------------------------------------------------------------------------
# Endpoints
# ---------------------------------------------------------------------------
@app.get("/health", tags=["Meta"])
def health():
    """Liveness check."""
    # Constant payload: answering at all is the signal; no model or artifact
    # is touched here.
    status_payload = {"status": "ok"}
    return status_payload
@app.post(
    "/predict_single",
    response_model=PredictSingleResponse,
    summary="Single prediction with mutual specific timestamp",
    tags=["Predict"],
)
def api_predict_single(request: PredictSingleRequest):
    """
    Predict upvotes for a single post context given both hours_ago and post_utc.
    """
    # Parameters: `request` is the validated PredictSingleRequest body.
    # Returns: PredictSingleResponse wrapping the value from predict_single.
    # Raises: HTTPException 422 when predict_single rejects the input with a
    #         ValueError; HTTPException 500 when the model artifact is missing.
    import logging  # function-scope import keeps this change self-contained

    try:
        result = predict_single(
            subreddit=request.subreddit,
            title=request.title,
            text=request.text,
            hours_ago=request.hours_ago,
            post_utc=request.post_utc,
            model_path=LOCAL_ARTIFACTS_MODEL_PATH,
        )
    except ValueError as e:
        # A ValueError is treated as a problem with the submitted values.
        raise HTTPException(status_code=422, detail=str(e)) from e
    except FileNotFoundError as e:
        # A missing model file is a deployment fault, not a client error:
        # report it as a server error and keep the filesystem path out of the
        # response body (it is logged server-side instead).
        logging.getLogger(__name__).exception(
            "Model artifact not found: %s", LOCAL_ARTIFACTS_MODEL_PATH
        )
        raise HTTPException(
            status_code=500,
            detail="Model artifact is not available on the server.",
        ) from e
    return PredictSingleResponse(prediction=result)
@app.post(
    "/predict_batch",
    response_model=PredictBatchResponse,
    summary="Batch prediction with matching length lists",
    tags=["Predict"],
)
def api_predict_batch(request: PredictBatchRequest):
    """
    Predict upvotes for an array of potential posting times.
    """
    # Parameters: `request` is the validated PredictBatchRequest body (both
    #             lists non-empty and of equal length).
    # Returns: PredictBatchResponse with the predict_batch results and their count.
    # Raises: HTTPException 422 when predict_batch rejects the input with a
    #         ValueError; HTTPException 500 when the model artifact is missing.
    import logging  # function-scope import keeps this change self-contained

    try:
        results = predict_batch(
            subreddit=request.subreddit,
            title=request.title,
            text=request.text,
            hours_ago_list=request.hours_ago_list,
            post_utc_list=request.post_utc_list,
            model_path=LOCAL_ARTIFACTS_MODEL_PATH,
        )
    except ValueError as e:
        # A ValueError is treated as a problem with the submitted values.
        raise HTTPException(status_code=422, detail=str(e)) from e
    except FileNotFoundError as e:
        # A missing model file is a deployment fault, not a client error:
        # report it as a server error and keep the filesystem path out of the
        # response body (it is logged server-side instead).
        logging.getLogger(__name__).exception(
            "Model artifact not found: %s", LOCAL_ARTIFACTS_MODEL_PATH
        )
        raise HTTPException(
            status_code=500,
            detail="Model artifact is not available on the server.",
        ) from e
    return PredictBatchResponse(predictions=results, count=len(results))
# ---------------------------------------------------------------------------
# Entry point
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    # Start the server only when this file is executed as a script; host
    # 0.0.0.0 listens on all network interfaces.
    # NOTE(review): port 7860 is presumably what the hosting platform expects — confirm.
    uvicorn.run(app, host="0.0.0.0", port=7860)
# git add <filename>
# git commit -m <message>
# git push
# git remote -v ???
# git config --global credential.helper store
# uvicorn app:app --reload
# git commit conventions?
# https://cyrilfrl-test.hf.space/greet
# git commit conventions
# curl -X POST https://cyrilfrl-test.hf.space/greet -H "Content-Type: application/json" -d '{"name": "Cyril"}'
# (dev) C:\Users\Cyril\Desktop\B3Q2\startup\app\HuggingFace\RedditPredictUpvotes\test>curl -X POST https://cyrilfrl-test.hf.space/greet -H "Content-Type: application/json" -d "{\"name\": \"Cyril\"}"
# {"message":"Hello Cyril!!"}
# (dev) C:\Users\Cyril\Desktop\B3Q2\startup\app\HuggingFace\RedditPredictUpvotes\test>
# curl -X POST "http://127.0.0.1:8000/greet" -H "Content-Type: application/json" -d "{\"name\": \"Cyril\"}"