sameernotes committed on
Commit
efa927d
·
verified ·
1 Parent(s): a1840c8

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +102 -127
app.py CHANGED
@@ -1,161 +1,136 @@
1
- from fastapi import FastAPI, HTTPException, Query
2
- from pydantic import BaseModel, Field, conlist
3
- import torch
4
- import torch.nn as nn
5
- import os # Import the 'os' module
6
  from typing import List
7
 
8
# --- Model Definition (same as before) ---
class NameGenderClassifierCNN(nn.Module):
    """Character-level 1-D CNN scoring a tokenized name with P(male) in [0, 1]."""

    def __init__(self, vocab_size, embedding_dim, num_filters=64, filter_sizes=[2, 3, 4], dropout=0.5):
        super(NameGenderClassifierCNN, self).__init__()

        # Index 0 is reserved for padding and contributes a zero vector.
        self.embedding = nn.Embedding(vocab_size, embedding_dim, padding_idx=0)

        # One convolution branch per character n-gram width.
        self.convs = nn.ModuleList(
            nn.Conv1d(in_channels=embedding_dim, out_channels=num_filters, kernel_size=width)
            for width in filter_sizes
        )

        self.fc1 = nn.Linear(len(filter_sizes) * num_filters, 100)
        self.fc2 = nn.Linear(100, 1)
        self.dropout = nn.Dropout(dropout)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        # (batch, seq) -> (batch, seq, emb) -> (batch, emb, seq) as Conv1d expects.
        features = self.embedding(x).transpose(1, 2)

        # Max-pool each branch over the time axis, then concatenate branches.
        pooled = []
        for conv in self.convs:
            activated = torch.relu(conv(features))
            pooled.append(torch.max_pool1d(activated, activated.shape[2]).squeeze(2))
        combined = torch.cat(pooled, dim=1)

        hidden = torch.relu(self.fc1(self.dropout(combined)))
        logits = self.fc2(self.dropout(hidden))
        return self.sigmoid(logits).squeeze()
39
 
 
 
 
 
40
 
41
 
42
# --- Utility Function (same as before, but adapted) ---

def tokenize_name(name, char_to_idx, max_length):
    """Map a name to a fixed-length list of character indices.

    Characters absent from the vocabulary fall back to the index of ' '
    (or 1 when ' ' is also missing). The result is right-padded with the
    '<PAD>' index or truncated so it has exactly max_length entries.
    """
    lowered = str(name).lower()
    fallback = char_to_idx.get(' ', 1)
    tokens = [char_to_idx.get(ch, fallback) for ch in lowered]

    # Truncate long names; pad short ones up to the fixed length.
    if len(tokens) >= max_length:
        return tokens[:max_length]
    padding = [char_to_idx['<PAD>']] * (max_length - len(tokens))
    return tokens + padding
56
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
57
 
58
# --- FastAPI Setup ---

# Public API surface; title/description/version appear in the OpenAPI docs.
app = FastAPI(title="Indian Name Gender Prediction API",
              description="Predicts the gender of Indian names using a CNN model.",
              version="1.0")

# --- Model Loading (on startup) ---

# Checkpoint location, resolved relative to the app's working directory.
MODEL_PATH = "models/indian_name_gender_model.pt" # Correct path within the space
 
 
 
67
 
 
 
68
 
69
def load_model():
    """Load the trained checkpoint and its preprocessing metadata.

    Returns:
        (model, char_to_idx, max_name_length, device) with the model moved to
        the best available device and switched to eval mode.

    Raises:
        RuntimeError: if the checkpoint is missing or malformed; the original
            error is chained so the real cause survives in the traceback.
    """
    try:
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        # NOTE: torch.load unpickles arbitrary objects — only load trusted checkpoints.
        checkpoint = torch.load(MODEL_PATH, map_location=device)
        char_to_idx = checkpoint['char_to_idx']
        max_name_length = checkpoint['max_name_length']
        config = checkpoint['model_config']

        model = NameGenderClassifierCNN(
            vocab_size=config['vocab_size'],
            embedding_dim=config['embedding_dim'],
            num_filters=config['num_filters'],
            filter_sizes=config['filter_sizes']
        )
        model.load_state_dict(checkpoint['model_state_dict'])
        model.to(device)
        model.eval()  # Set to evaluation mode
        return model, char_to_idx, max_name_length, device
    except Exception as e:
        # RuntimeError is still an Exception, so existing callers' handlers match;
        # "from e" preserves the underlying traceback instead of discarding it.
        raise RuntimeError(f"Error loading model: {e}") from e
90
 
91
# Load model at startup
try:
    model, char_to_idx, max_name_length, device = load_model()
except Exception as e:
    # A missing or corrupt checkpoint makes every endpoint useless: fail fast
    # so the hosting platform surfaces the startup error immediately.
    print(f"Failed to load model: {e}")
    raise # Re-raise the exception to halt startup
97
 
98
# --- Pydantic Models (for request/response validation) ---

class PredictionRequest(BaseModel):
    # At least one name is required (conlist enforces min_length=1).
    names: conlist(str, min_length=1) = Field(..., example=["Aarav", "Anika"])
    # Scores >= threshold are labeled Male; below it, Female.
    threshold: float = Field(0.5, ge=0.0, le=1.0, description="Probability threshold for classifying as male.")
 
 
 
 
 
103
 
104
class PredictionResponse(BaseModel):
    # One dict per input name: label, raw P(male), and confidence in the label.
    predictions: List[dict] = Field(..., example=[
        {"name": "Aarav", "predicted_gender": "Male", "male_probability": 0.95, "confidence": 0.95},
        {"name": "Anika", "predicted_gender": "Female", "male_probability": 0.05, "confidence": 0.95}
    ])
109
 
110
 
111
# --- Prediction Function ---

def predict_gender(name: str, model, char_to_idx, max_length, device, threshold: float = 0.5) -> tuple[str, float, float]:
    """Score a single name and map the probability to a gender label.

    Returns (label, male_probability, confidence), where confidence is the
    probability of whichever label the threshold selected.
    """
    token_ids = tokenize_name(name, char_to_idx, max_length)
    batch = torch.tensor([token_ids], dtype=torch.long).to(device)

    # Inference only — no gradients needed.
    with torch.no_grad():
        probability = model(batch).item()

    is_male = probability >= threshold
    label = 'Male' if is_male else 'Female'
    certainty = probability if is_male else 1 - probability
    return label, probability, certainty
124
 
125
# --- API Endpoints ---

@app.get("/", response_model=str)
async def read_root():
    """Landing route that points clients at the prediction endpoints."""
    return "Welcome to the Indian Name Gender Prediction API. Use the /predict endpoint."
130
 
131
@app.post("/predict", response_model=PredictionResponse)
async def predict(request: PredictionRequest):
    """Predicts the gender of one or more Indian names."""
    try:
        results = []
        for candidate in request.names:
            gender, prob, conf = predict_gender(
                candidate, model, char_to_idx, max_name_length, device, request.threshold
            )
            results.append({
                "name": candidate,
                "predicted_gender": gender,
                "male_probability": prob,
                "confidence": conf
            })
        return {"predictions": results}
    except Exception as e:
        # Surface any model/tokenization failure as a 500 with the message.
        raise HTTPException(status_code=500, detail=str(e))
147
-
148
@app.get("/predict_single")
async def predict_single(name: str = Query(..., description="The name to predict."),
                         threshold: float = Query(0.5, ge=0.0, le=1.0, description="Probability threshold for classifying as male.")):
    """Predicts gender for a *single* name, provided as a query parameter."""
    try:
        gender, prob, conf = predict_gender(name, model, char_to_idx, max_name_length, device, threshold)
    except Exception as e:
        # Any inference failure becomes a 500 carrying the error message.
        raise HTTPException(status_code=500, detail=str(e))
    return {
        "name": name,
        "predicted_gender": gender,
        "male_probability": prob,
        "confidence": conf
    }
 
1
import os
from typing import List, Optional

from fastapi import FastAPI, HTTPException, status
from pydantic import BaseModel
from google import genai
from google.genai import types
from google.protobuf.json_format import MessageToDict  # For converting to dict
9
 
10
app = FastAPI()

# Load API key from environment variable
GEMINI_API_KEY = os.environ.get("GEMINI_API_KEY")
if not GEMINI_API_KEY:
    # Fail fast at import time: without a key no endpoint can work.
    raise ValueError("The GEMINI_API_KEY environment variable is not set.")

# NOTE(review): `from google import genai` imports the new google-genai SDK,
# which exposes genai.Client(api_key=...) and has no configure() or
# GenerativeModel(); the two calls below match the older google-generativeai
# package instead — confirm which package is actually installed.
genai.configure(api_key=GEMINI_API_KEY)
client = genai.GenerativeModel('gemini-pro') # Use a consistent model. 'gemini-pro' is better for text.
 
 
19
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
20
 
21
class TranslationRequest(BaseModel):
    """Request body for /translate."""
    # Text to translate; emptiness is rejected in the endpoint with a 400.
    text: str
    target_language: str  # Accept full language name, e.g., "Telugu", "Tamil", "Hindi"
    # Optional: the caller *might* pin the source; otherwise it is auto-detected.
    # Annotated Optional[str] so an explicit JSON null also validates instead of 422-ing.
    source_language: Optional[str] = None
25
 
26
 
27
class TranslationResponse(BaseModel):
    translated_text: str
    # Always return the detected/used source language
    source_language: str
    target_language: str
31
 
 
 
 
 
32
 
 
 
 
 
 
33
 
34
# --- Helper Functions ---

def detect_language_and_options(text: str):
    """Detect the language of *text* and offer five translation options.

    Returns:
        (source_language, options) where options maps "1".."5" to language names.

    Raises:
        HTTPException: 500 when the model's reply does not match the expected format.
    """
    # NOTE(review): the formatting instruction is sent in a "model"-role turn —
    # presumably intended as priming; confirm this is what the API expects.
    contents = [
        types.Content(
            role="user",
            parts=[types.Part.from_text(text=text)],
        ),
        types.Content(
            role="model",
            parts=[
                types.Part.from_text(
                    text="""Please identify the language of the text provided and then offer translation options as numbered choices (1-5). Use this format: "The text is in [Language]. Choose a language to translate to: 1. [Option 1], 2. [Option 2], 3. [Option 3], 4. [Option 4], 5. [Option 5]"."""
                )
            ]
        )
    ]

    response = client.generate_content(contents)
    # Extract language and make options consistent. Robust parsing.
    try:
        response_text = response.text
        source_language = response_text.split("The text is in ")[1].split(".")[0].strip()

        options_str = response_text.split("Choose a language to translate to:")[1].strip()
        # Tolerate a missing "N. " prefix on any option segment instead of
        # crashing with IndexError (the previous split(". ")[1] did).
        options_list = []
        for segment in options_str.split(", "):
            _, sep, label = segment.partition(". ")
            options_list.append(label.strip() if sep else segment.strip())
        # Ensure we have *exactly* 5 options, padding if needed.
        while len(options_list) < 5:
            options_list.append("Option Not Available")  # Or some other placeholder
        options = {str(i + 1): lang for i, lang in enumerate(options_list[:5])}
        return source_language, options

    except (IndexError, AttributeError):  # Handle parsing errors gracefully
        raise HTTPException(status_code=500, detail="Error processing language detection response.")
74
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
75
 
 
 
 
 
 
 
76
 
77
def translate_with_gemini(text: str, source_language: str, target_language: str) -> str:
    """Translates text using Gemini Pro, handling language codes correctly."""
    # One direct instruction is enough; no few-shot examples for plain translation.
    reply = client.generate_content(
        f"Translate the following text from {source_language} to {target_language}:\n\n{text}"
    )
    try:
        return reply.text
    except (AttributeError, IndexError) as exc:
        raise HTTPException(status_code=500, detail=f"Error from Gemini API: {exc}")
88
 
 
 
 
 
 
89
 
90
 
 
91
 
92
@app.post("/translate", response_model=TranslationResponse, status_code=status.HTTP_200_OK)
async def translate(request: TranslationRequest):
    """Translates text from a source language to a target language."""
    if not request.text:
        raise HTTPException(status_code=400, detail="Text to translate cannot be empty.")
    if not request.target_language:
        raise HTTPException(status_code=400, detail="Target language must be provided.")

    # Validate the target language *before* any detection so an unsupported
    # request fails with 400 without spending a Gemini API call.
    supported_languages = ["English", "Hindi", "Telugu", "Marathi", "Bengali", "Tamil", "Spanish", "French", "German", "Japanese", "Chinese"]  # Add more as needed.
    if request.target_language not in supported_languages:
        raise HTTPException(status_code=400, detail=f"Target language '{request.target_language}' is not supported. Supported languages: {', '.join(supported_languages)}")

    if request.source_language:  # User provided source language. Use it directly.
        source_language = request.source_language
    else:  # Detect the language
        try:
            source_language, _ = detect_language_and_options(request.text)  # We don't need options here
        except HTTPException:
            raise  # Propagate helper errors unchanged.
        except Exception as e:  # Catch any other unexpected errors.
            raise HTTPException(status_code=500, detail=f"Language detection failed: {e}")

    try:
        translated_text = translate_with_gemini(request.text, source_language, request.target_language)
        return TranslationResponse(translated_text=translated_text, source_language=source_language, target_language=request.target_language)
    except HTTPException:
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Translation failed: {e}")
123
+
124
+
125
@app.post("/detect_language", status_code=status.HTTP_200_OK)
async def detect_language(text: str = ""):  # Simpler input, just the text
    """Detects the language of the input text and provides translation options."""
    if not text:
        raise HTTPException(status_code=400, detail="Text to detect cannot be empty.")
    try:
        detected, choices = detect_language_and_options(text)
    except HTTPException:
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"An unexpected error occurred: {e}")
    return {"source_language": detected, "translation_options": choices}