# ASR_Models / app.py - Hugging Face Space by aditagrawal
# Commit dd1cc18: "Create app.py" (verified)
# app.py
import torch
import uvicorn
from datasets import load_dataset
from fastapi import FastAPI, HTTPException
from fastapi.concurrency import run_in_threadpool
from pydantic import BaseModel
from sklearn.metrics import accuracy_score
from transformers import AutoModelForSequenceClassification, AutoTokenizer
# The ASGI application; the routes below register on it and the
# __main__ guard hands it to uvicorn.
app: FastAPI = FastAPI()
class ModelRequest(BaseModel):
    """JSON body accepted by ``POST /submit``.

    Both values are Hugging Face Hub repository ids handed directly to
    ``from_pretrained``.
    """

    # Repo id of the checkpoint loaded as a sequence-classification model.
    model_id: str
    # Repo id of the tokenizer (often the same repo as ``model_id``).
    tokenizer_id: str
# In-memory leaderboard to store results
leaderboard = []
def load_model_and_tokenizer(model_id: str, tokenizer_id: str):
    """Load the model and tokenizer from Hugging Face Hub.

    Args:
        model_id: Hub repo id of a sequence-classification checkpoint.
        tokenizer_id: Hub repo id of the matching tokenizer.

    Returns:
        A ``(model, tokenizer)`` tuple.

    Raises:
        HTTPException: status 400 if either download/load fails; the
            underlying error is chained as ``__cause__`` so server logs keep
            the full traceback.
    """
    # NOTE(security): both ids come straight from an HTTP request. Depending
    # on the installed transformers/torch versions, loading a non-safetensors
    # checkpoint may unpickle data from an arbitrary repo - consider passing
    # use_safetensors=True here. trust_remote_code is left at its default.
    try:
        model = AutoModelForSequenceClassification.from_pretrained(model_id)
        tokenizer = AutoTokenizer.from_pretrained(tokenizer_id)
    except Exception as e:
        # API boundary: any load failure is reported to the client as a bad
        # request, but the original exception is preserved via chaining.
        raise HTTPException(status_code=400, detail=str(e)) from e
    return model, tokenizer
def evaluate_model(model, tokenizer, batch_size: int = 32):
    """Evaluate the model on the GLUE MRPC test split and return its accuracy.

    Sentence pairs are tokenized and scored in mini-batches of ``batch_size``
    so peak memory stays bounded instead of growing with the split size.

    Args:
        model: Sequence-classification model to score.
        tokenizer: Tokenizer matching ``model``.
        batch_size: Number of sentence pairs per forward pass.

    Returns:
        Fraction of test pairs whose predicted class equals the gold label.

    Raises:
        ValueError: If ``batch_size`` is not a positive integer.
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be positive, got {batch_size}")

    # Load a benchmark dataset (replace with your actual dataset)
    test_split = load_dataset("glue", "mrpc")["test"]
    # Materialize the columns as plain lists so they slice cleanly and are
    # accepted by the tokenizer regardless of the datasets version.
    sentences1 = list(test_split["sentence1"])
    sentences2 = list(test_split["sentence2"])
    labels = list(test_split["label"])

    # NOTE(review): argmax over the logits assumes the submitted model's label
    # ids line up with MRPC's labels - confirm per checkpoint; a head trained
    # for another task will produce a meaningless score.
    model.eval()  # make sure dropout is off for deterministic predictions
    predictions = []
    with torch.no_grad():
        for start in range(0, len(labels), batch_size):
            stop = start + batch_size
            inputs = tokenizer(
                sentences1[start:stop],
                sentences2[start:stop],
                padding=True,
                truncation=True,
                return_tensors="pt",
            )
            logits = model(**inputs).logits
            # .cpu() keeps this working if the model was moved to a GPU.
            predictions.extend(logits.argmax(dim=-1).cpu().tolist())

    return float(accuracy_score(labels, predictions))
def update_leaderboard(model_id: str, score: float):
    """Record one result and re-rank the leaderboard, highest score first.

    Repeat submissions of the same model add a new entry each time. The sort
    is stable, so among equal scores the earlier submission stays ahead.
    """
    def by_score(entry):
        return entry["score"]

    leaderboard.append(dict(model_id=model_id, score=score))
    leaderboard.sort(key=by_score, reverse=True)
@app.post("/submit")
async def submit_model(request: ModelRequest):
    """Endpoint to submit a model for evaluation.

    Loading and evaluating a model are long, blocking calls. Running them
    directly inside this ``async def`` would freeze the event loop, and every
    other request (e.g. ``GET /leaderboard``) would hang until evaluation
    finished. They are therefore off-loaded to FastAPI's worker threadpool.
    The leaderboard update stays on the event-loop thread, so the shared list
    is never mutated from two threads at once.

    Returns:
        ``{"message": ..., "score": ...}`` for the evaluated model.

    Raises:
        HTTPException: 400 if the model or tokenizer cannot be loaded.
    """
    # NOTE(review): several submissions can now evaluate concurrently, each
    # holding its own model in memory; add a semaphore if that is a concern.
    model_id = request.model_id
    model, tokenizer = await run_in_threadpool(
        load_model_and_tokenizer, model_id, request.tokenizer_id
    )
    score = await run_in_threadpool(evaluate_model, model, tokenizer)
    update_leaderboard(model_id, score)
    return {"message": "Model evaluated successfully", "score": score}
@app.get(path="/leaderboard")
async def get_leaderboard():
    """Serve every recorded result, ordered best score first.

    Each item is a ``{"model_id": ..., "score": ...}`` dict as stored by
    ``update_leaderboard``.
    """
    return leaderboard
if __name__ == "__main__":
    import os

    # Listen on all interfaces. The port defaults to 8000 but can be
    # overridden through the conventional PORT environment variable for
    # hosts that assign the port at deploy time.
    uvicorn.run(app, host="0.0.0.0", port=int(os.environ.get("PORT", "8000")))