# chalana2001's picture
# Update app.py
# bbf26c0 verified
import os
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
import torch
app = FastAPI()

# Allow any browser origin to call this API (the Space UI is served from a
# different origin than the API).
# NOTE(review): allow_credentials=True combined with allow_origins=["*"] is a
# known CORS pitfall — browsers reject credentialed responses whose
# Access-Control-Allow-Origin is the wildcard, so credentialed cross-origin
# requests will fail. If cookies/auth headers are actually needed, list
# explicit origins; otherwise allow_credentials can be dropped. TODO confirm
# the intended clients.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
class PromptRequest(BaseModel):
    """Request body for POST /predict: the text prompt to feed the model."""

    # Raw input text; tokenized and passed to the seq2seq model as-is.
    prompt: str
# Path to the model folder inside the Space. Overridable via the MODEL_PATH
# environment variable; defaults to the Space root, where the weights sit
# alongside this script.
MODEL_PATH = os.environ.get("MODEL_PATH", "./")

# Load tokenizer and model once at import time so all requests share them.
# use_fast=False forces the slow (Python/SentencePiece) tokenizer, which some
# seq2seq checkpoints require — presumably deliberate here; verify against the
# checkpoint before changing.
tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH, use_fast=False)
model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_PATH)
@app.get("/")
async def health_check():
    """Liveness probe: confirm the API process is up and serving."""
    status_payload = {
        "status": "healthy",
        "message": "API is running",
    }
    return status_payload
@app.post("/predict")
async def predict(request: PromptRequest):
    """Run seq2seq generation on the submitted prompt.

    Tokenizes ``request.prompt`` (truncated/padded to the model's limits),
    generates up to 256 new tokens, and returns the decoded text.

    Returns:
        dict: ``{"result": <generated text>}``.
    """
    inputs = tokenizer(
        request.prompt, return_tensors="pt", truncation=True, padding=True
    )
    # Inference only: disable autograd so no computation graph or activations
    # are retained per request — avoids unnecessary memory growth under load.
    with torch.no_grad():
        outputs = model.generate(**inputs, max_new_tokens=256)
    result = tokenizer.decode(outputs[0], skip_special_tokens=True)
    return {"result": result}
if __name__ == "__main__":
    import uvicorn

    # 7860 is the standard Hugging Face Spaces port; allow a PORT env-var
    # override for local development without editing the file.
    port = int(os.environ.get("PORT", "7860"))
    uvicorn.run(app, host="0.0.0.0", port=port)