Spaces:
Sleeping
Sleeping
File size: 6,081 Bytes
12fbfda |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 |
from fastapi import FastAPI, HTTPException, Query
from pydantic import BaseModel, Field, conlist
import torch
import torch.nn as nn
import os # Import the 'os' module
from typing import List
# --- Model Definition (same as before) ---
class NameGenderClassifierCNN(nn.Module):
    """Character-level CNN for binary (Male/Female) name classification.

    Embeds each character, applies parallel 1-D convolutions of several
    kernel widths, max-pools each feature map over the time axis, and runs
    the concatenated features through a two-layer head ending in a sigmoid
    that yields a male-probability per name.
    """

    def __init__(self, vocab_size, embedding_dim, num_filters=64, filter_sizes=[2, 3, 4], dropout=0.5):
        super(NameGenderClassifierCNN, self).__init__()
        # NOTE: submodule registration order is kept as-is so state_dict
        # keys stay compatible with existing checkpoints.
        self.embedding = nn.Embedding(vocab_size, embedding_dim, padding_idx=0)
        self.convs = nn.ModuleList([
            nn.Conv1d(in_channels=embedding_dim, out_channels=num_filters, kernel_size=width)
            for width in filter_sizes
        ])
        self.fc1 = nn.Linear(len(filter_sizes) * num_filters, 100)
        self.fc2 = nn.Linear(100, 1)
        self.dropout = nn.Dropout(dropout)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        """Map (batch, seq_len) token ids to a squeezed male-probability tensor."""
        emb = self.embedding(x).transpose(1, 2)  # (batch, embedding_dim, seq_len)
        # Global max over time: identical to max_pool1d with a kernel
        # spanning the full (post-conv) sequence length.
        pooled = [torch.relu(conv(emb)).amax(dim=2) for conv in self.convs]
        features = self.dropout(torch.cat(pooled, dim=1))
        hidden = self.dropout(torch.relu(self.fc1(features)))
        logits = self.fc2(hidden)
        # .squeeze() also drops the batch dim when batch == 1; callers
        # rely on .item() so this matches the original behavior.
        return self.sigmoid(logits).squeeze()
# --- Utility Function (same as before, but adapted) ---
def tokenize_name(name, char_to_idx, max_length):
    """Convert *name* into a fixed-length list of character indices.

    The name is lower-cased; characters missing from *char_to_idx* fall
    back to the index of ' ' (or 1 when ' ' is absent). The result is
    right-padded with the '<PAD>' index or truncated so its length is
    exactly *max_length*.
    """
    lowered = str(name).lower()
    unknown_idx = char_to_idx.get(' ', 1)
    tokens = [char_to_idx.get(ch, unknown_idx) for ch in lowered]
    shortfall = max_length - len(tokens)
    if shortfall > 0:
        tokens.extend([char_to_idx['<PAD>']] * shortfall)
    # Truncation is a no-op when we just padded.
    return tokens[:max_length]
# --- FastAPI Setup ---
# Application instance; title/description/version surface in the OpenAPI docs UI.
app = FastAPI(title="Indian Name Gender Prediction API",
              description="Predicts the gender of Indian names using a CNN model.",
              version="1.0")
# --- Model Loading (on startup) ---
MODEL_PATH = "models/indian_name_gender_model.pt"  # Correct path within the space

def load_model():
    """Load the trained checkpoint and rebuild the model for inference.

    Returns:
        Tuple of (model, char_to_idx, max_name_length, device), with the
        model already moved to *device* and switched to eval mode.

    Raises:
        RuntimeError: if the checkpoint is missing or malformed. The
            original exception is chained as the cause so startup logs
            show the real failure (the bare ``Exception`` previously
            raised here hid the chain and was needlessly broad).
    """
    try:
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        # The checkpoint bundles vocab, padding length, model config and weights.
        checkpoint = torch.load(MODEL_PATH, map_location=device)
        char_to_idx = checkpoint['char_to_idx']
        max_name_length = checkpoint['max_name_length']
        config = checkpoint['model_config']
        model = NameGenderClassifierCNN(
            vocab_size=config['vocab_size'],
            embedding_dim=config['embedding_dim'],
            num_filters=config['num_filters'],
            filter_sizes=config['filter_sizes']
        )
        model.load_state_dict(checkpoint['model_state_dict'])
        model.to(device)
        model.eval()  # Inference only: disables dropout.
        return model, char_to_idx, max_name_length, device
    except Exception as e:
        # Same message as before so the startup handler's log is unchanged;
        # RuntimeError is still caught by callers catching Exception.
        raise RuntimeError(f"Error loading model: {e}") from e
# Load model at startup — once at import time so all request handlers
# share a single model instance, vocab, and device.
try:
    model, char_to_idx, max_name_length, device = load_model()
except Exception as e:
    print(f"Failed to load model: {e}")
    raise  # Re-raise the exception to halt startup
# --- Pydantic Models (for request/response validation) ---
class PredictionRequest(BaseModel):
    """Request body for POST /predict: one or more names plus an optional threshold."""
    # conlist enforces at least one name; threshold is clamped to [0, 1] by Field.
    names: conlist(str, min_length=1) = Field(..., example=["Aarav", "Anika"])
    threshold: float = Field(0.5, ge=0.0, le=1.0, description="Probability threshold for classifying as male.")
class PredictionResponse(BaseModel):
    """Response body for POST /predict: one result record per input name."""
    # Each record mirrors the tuple returned by predict_gender plus the input name.
    predictions: List[dict] = Field(..., example=[
        {"name": "Aarav", "predicted_gender": "Male", "male_probability": 0.95, "confidence": 0.95},
        {"name": "Anika", "predicted_gender": "Female", "male_probability": 0.05, "confidence": 0.95}
    ])
# --- Prediction Function ---
def predict_gender(name: str, model, char_to_idx, max_length, device, threshold: float = 0.5) -> tuple[str, float, float]:
    """Classify a single name.

    Returns a (label, male_probability, confidence) triple, where label is
    'Male' when the probability meets *threshold*, and confidence is the
    probability mass on the chosen side.
    """
    ids = tokenize_name(name, char_to_idx, max_length)
    batch = torch.tensor([ids], dtype=torch.long).to(device)
    with torch.no_grad():
        male_prob = model(batch).item()
    is_male = male_prob >= threshold
    label = 'Male' if is_male else 'Female'
    confidence = male_prob if is_male else 1 - male_prob
    return label, male_prob, confidence
# --- API Endpoints ---
@app.get("/", response_model=str)
async def read_root():
    """Landing endpoint; directs clients to the prediction routes."""
    message = "Welcome to the Indian Name Gender Prediction API. Use the /predict endpoint."
    return message
@app.post("/predict", response_model=PredictionResponse)
async def predict(request: PredictionRequest):
    """Batch endpoint: predict the gender of every name in the request.

    Any failure is surfaced as an HTTP 500 whose detail is the error text.
    """
    def _as_record(candidate: str) -> dict:
        # One result row per input name, using the request-wide threshold.
        label, male_prob, conf = predict_gender(
            candidate, model, char_to_idx, max_name_length, device, request.threshold
        )
        return {
            "name": candidate,
            "predicted_gender": label,
            "male_probability": male_prob,
            "confidence": conf
        }

    try:
        return {"predictions": [_as_record(n) for n in request.names]}
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
@app.get("/predict_single")
async def predict_single(name: str = Query(..., description="The name to predict."),
                         threshold: float = Query(0.5, ge=0.0, le=1.0, description="Probability threshold for classifying as male.")):
    """Predicts gender for a *single* name, provided as a query parameter.

    Returns the same record shape as one entry of /predict; any failure
    is mapped to an HTTP 500 with the error text as detail.
    (Fix: removed a stray trailing '|' token after the final raise that
    made the line invalid Python.)
    """
    try:
        gender, prob, conf = predict_gender(name, model, char_to_idx, max_name_length, device, threshold)
        return {
            "name": name,
            "predicted_gender": gender,
            "male_probability": prob,
            "confidence": conf
        }
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))