import torch
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

app = FastAPI()

# Allow cross-origin requests from any host (typical for a public demo Space).
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)


class PromptRequest(BaseModel):
    prompt: str


# Path to the model folder inside the Space (the repository root here).
MODEL_PATH = "./"

# Load the tokenizer and model once at startup so every request reuses them.
tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH, use_fast=False)
model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_PATH)


@app.get("/")
async def health_check():
    return {"status": "healthy", "message": "API is running"}


@app.post("/predict")
async def predict(request: PromptRequest):
    inputs = tokenizer(
        request.prompt, return_tensors="pt", truncation=True, padding=True
    )
    # Generation is inference-only, so skip gradient tracking to save memory.
    with torch.no_grad():
        outputs = model.generate(**inputs, max_new_tokens=256)
    result = tokenizer.decode(outputs[0], skip_special_tokens=True)
    return {"result": result}


if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=7860)
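
# A minimal usage sketch, kept as a comment so this file stays a valid server
# module. It assumes the server is running locally on port 7860 (as configured
# above); the example prompt is hypothetical and depends on the bundled model.
#
#   curl -X POST http://localhost:7860/predict \
#        -H "Content-Type: application/json" \
#        -d '{"prompt": "Translate English to German: Hello, world."}'
#
# The response should have the shape: {"result": "<generated text>"}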