"""FastAPI app exposing a Hugging Face text-generation pipeline over HTTP."""

import os
from typing import Optional

from fastapi import FastAPI
from pydantic import BaseModel, Field
from transformers import pipeline

# Model to serve; override with the MODEL_ID environment variable.
MODEL_ID = os.getenv("MODEL_ID", "distilgpt2")
# Directory used for downloaded model/tokenizer files; taken from HF_HOME when set.
CACHE_DIR = os.getenv("HF_HOME", "/app/.cache")
os.makedirs(CACHE_DIR, exist_ok=True)

app = FastAPI(title="FastAPI Hugging Face Space")

# Built once at import time so every request reuses the same loaded model.
# cache_dir is passed via model_kwargs: pipeline() forwards model_kwargs to the
# config/model/tokenizer from_pretrained() calls, whereas an unrecognized
# top-level kwarg is handed to the pipeline object itself and may end up being
# forwarded to generate() instead of controlling where files are cached.
generator = pipeline(
    "text-generation",
    model=MODEL_ID,
    model_kwargs={"cache_dir": CACHE_DIR},
)


class GenerateRequest(BaseModel):
    """Request body for ``POST /generate``."""

    # Text for the model to continue; an empty prompt is rejected with 422
    # rather than failing inside the pipeline.
    prompt: str = Field(..., min_length=1)
    # Passed straight through to the pipeline as ``max_length`` (total length,
    # prompt included). Must be positive when given; None leaves the
    # pipeline's own default in place.
    max_length: Optional[int] = Field(64, ge=1)


@app.get("/")
async def root():
    """Return a short usage hint."""
    return {"message": "API running. Use POST /generate to generate text."}


@app.get("/health")
async def health():
    """Report liveness together with the configured model id and cache directory."""
    return {"status": "ok", "model": MODEL_ID, "cache": CACHE_DIR}


@app.post("/generate")
def generate(req: GenerateRequest):
    """Generate a single continuation of ``req.prompt``.

    Declared as a plain ``def`` on purpose: the pipeline call is blocking,
    compute-heavy work. FastAPI runs sync path operations in a worker
    threadpool rather than on the event loop, so other requests (e.g.
    ``/health``) are still served while a generation is running.

    Returns:
        ``{"generated_text": <str>}`` taken from the first (and only) sequence.
    """
    result = generator(req.prompt, max_length=req.max_length, num_return_sequences=1)
    return {"generated_text": result[0]["generated_text"]}


if __name__ == "__main__":
    import uvicorn

    # Defaults to the originally hard-coded port 7860; PORT overrides it.
    uvicorn.run(app, host="0.0.0.0", port=int(os.getenv("PORT", "7860")))