Spaces:
Runtime error
Runtime error
| # import torch | |
| # from fastapi import FastAPI, HTTPException | |
| # from pydantic import BaseModel | |
| # from transformers import AutoTokenizer | |
| # from torch.utils.data import DataLoader | |
| # from news_dataset import NewsDataset | |
| # from utils import load_model, predict_category | |
| # # Initialize FastAPI app | |
| # app = FastAPI() | |
| # # Load dataset and model | |
| # dataset = NewsDataset(csv_file="./inshort_news_data.csv", max_length=100) | |
| # num_classes = len(dataset.labels_dict) | |
| # model_path = './models/trained_model.pth' # Path to your trained model | |
| # model = load_model(model_path, num_classes) | |
| # labels_dict = dataset.labels_dict | |
| # # Tokenizer initialization | |
| # tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased') | |
| # # Define Pydantic model for input data | |
| # class RequestPost(BaseModel): | |
| # headline: str | |
| # article: str | |
| # @app.get("/") | |
| # def read_root(): | |
| # return {"Hello": "World"} | |
| # # Define endpoint for prediction | |
| # @app.post("/predict/") | |
| # def prediction(request: RequestPost): | |
| # try: | |
| # category, score = predict_category(request.headline, request.article, model, labels_dict) | |
| # return {"category": category, "score": score} | |
| # except Exception as e: | |
| # raise HTTPException(status_code=500, detail=str(e)) | |
| from fastapi import FastAPI, HTTPException | |
| from pydantic import BaseModel | |
| from typing import List, Optional | |
| from transformers import AutoTokenizer | |
| from torch.utils.data import DataLoader | |
| from news_dataset import NewsDataset | |
| from utils import load_model, predict_category | |
# FastAPI application object.
app = FastAPI()

# Load the news dataset once at import time. Its label mapping is kept at
# module level and handed to predict_category on every request; the number
# of entries in it is passed to load_model as the class count.
dataset = NewsDataset(csv_file="./inshort_news_data.csv", max_length=100)
labels_dict = dataset.labels_dict
num_classes = len(labels_dict)

# Trained weights are read from this path, relative to the working directory.
model_path = "./models/trained_model1.pth"
model = load_model(model_path, num_classes)

# Tokenizer for the 'bert-base-uncased' checkpoint.
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
# Request body schema for the prediction endpoint.
# Both fields are required strings validated by pydantic.
class RequestPost(BaseModel):
    # Headline text of the news item.
    headline: str
    # Article text of the news item.
    article: str
# Register the handler on the application; without the decorator the
# function is defined but never exposed as a route.
@app.get("/")
def read_root():
    """Return a fixed greeting payload for the root path.

    Returns:
        The dict ``{"Hello": "World"}``.
    """
    return {"Hello": "World"}
# Define endpoint for prediction
# The decorator registers the handler on the application; without it the
# function is defined but never exposed as a route.
@app.post("/predict/")
def prediction(request: RequestPost):
    """Classify a news item and return its predicted category and score.

    Args:
        request: Request body carrying the ``headline`` and ``article`` text.

    Returns:
        A dict with keys ``"category"`` and ``"score"`` holding the two
        values returned by ``predict_category``.

    Raises:
        HTTPException: With status code 500 and the original error message
            as ``detail`` when ``predict_category`` raises.
    """
    try:
        # Only the prediction call is guarded; building the response dict
        # below cannot fail.
        category, score = predict_category(
            request.headline, request.article, model, labels_dict
        )
    except Exception as e:
        # Chain the cause so the original traceback is kept in server logs.
        raise HTTPException(status_code=500, detail=str(e)) from e
    return {"category": category, "score": score}