subbunanepalli committed on
Commit
d133e6d
·
verified ·
1 Parent(s): e489e8d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +22 -13
app.py CHANGED
@@ -14,10 +14,10 @@ from sklearn.pipeline import Pipeline
14
  # ========== Config ==========
15
  DATA_PATH = "data/synthetic_transactions_samples_5000.csv"
16
  MODEL_DIR = "models"
17
- MODEL_PATH = os.path.join(MODEL_DIR, "logreg_pipeline.pkl")
18
 
19
  # ========== FastAPI Init ==========
20
- app = FastAPI(title="TFIDF Logistic Regression Classifier")
21
 
22
  # ========== Input Schema ==========
23
  class TransactionData(BaseModel):
@@ -132,38 +132,46 @@ def create_text_input(row):
132
  Beneficial Owner: {row['Beneficial_Owner']}
133
  """
134
 
135
- # ========== Endpoints ==========
 
 
 
136
 
 
137
  @app.post("/train")
138
  def train_model():
139
- df = pd.read_csv(DATA_PATH).fillna("")
 
140
  df["text_input"] = df.apply(create_text_input, axis=1)
141
 
142
  X = df["text_input"]
143
- y = df[["Maker_Action", "Escalation_Level", "Risk_Category", "Risk_Drivers", "Investigation_Outcome"]]
144
 
145
  X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
146
 
 
 
147
  pipeline = Pipeline([
148
- ("vectorizer", TfidfVectorizer()),
149
- ("classifier", MultiOutputClassifier(LogisticRegression(max_iter=1000)))
150
  ])
151
 
152
  pipeline.fit(X_train, y_train)
 
153
  os.makedirs(MODEL_DIR, exist_ok=True)
154
  joblib.dump(pipeline, MODEL_PATH)
155
 
156
  accuracy = pipeline.score(X_test, y_test)
157
- return {"message": "Model trained successfully.", "accuracy": accuracy}
158
 
159
  @app.post("/predict")
160
  def predict(request: TransactionData):
161
  try:
162
  model = joblib.load(MODEL_PATH)
163
- input_df = pd.DataFrame([request.dict()]).fillna("")
164
- text_input = create_text_input(input_df.iloc[0])
 
165
  prediction = model.predict([text_input])[0]
166
-
167
  return {
168
  "Maker_Action": prediction[0],
169
  "Escalation_Level": prediction[1],
@@ -179,5 +187,6 @@ def validate_input(request: TransactionData):
179
  return {"message": "Input is valid."}
180
 
181
  @app.get("/test")
182
- def test():
183
- return {"message": "API is working."}
 
 
14
  # ========== Config ==========
15
  DATA_PATH = "data/synthetic_transactions_samples_5000.csv"
16
  MODEL_DIR = "models"
17
+ MODEL_PATH = os.path.join(MODEL_DIR, "logreg_model.pkl")
18
 
19
  # ========== FastAPI Init ==========
20
+ app = FastAPI()
21
 
22
  # ========== Input Schema ==========
23
  class TransactionData(BaseModel):
 
132
  Beneficial Owner: {row['Beneficial_Owner']}
133
  """
134
 
135
+ # ========== Root ==========
136
+ @app.get("/")
137
+ def root():
138
+ return {"message": "TF-IDF Logistic Regression API is running."}
139
 
140
+ # ========== API Routes ==========
141
  @app.post("/train")
142
  def train_model():
143
+ df = pd.read_csv(DATA_PATH)
144
+ df = df.fillna("")
145
  df["text_input"] = df.apply(create_text_input, axis=1)
146
 
147
  X = df["text_input"]
148
+ y = df[["Maker_Action", "Escalation_Level", "Risk_Category", "Risk_Drivers", "Investigation_Outcome", "Red_Flag_Reason"]]
149
 
150
  X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
151
 
152
+ vectorizer = TfidfVectorizer()
153
+ classifier = MultiOutputClassifier(LogisticRegression(max_iter=1000))
154
  pipeline = Pipeline([
155
+ ("vectorizer", vectorizer),
156
+ ("classifier", classifier)
157
  ])
158
 
159
  pipeline.fit(X_train, y_train)
160
+
161
  os.makedirs(MODEL_DIR, exist_ok=True)
162
  joblib.dump(pipeline, MODEL_PATH)
163
 
164
  accuracy = pipeline.score(X_test, y_test)
165
+ return {"message": "Model trained and saved.", "accuracy": accuracy}
166
 
167
  @app.post("/predict")
168
  def predict(request: TransactionData):
169
  try:
170
  model = joblib.load(MODEL_PATH)
171
+ input_data = pd.DataFrame([request.dict()])
172
+ input_data = input_data.fillna("")
173
+ text_input = create_text_input(input_data.iloc[0])
174
  prediction = model.predict([text_input])[0]
 
175
  return {
176
  "Maker_Action": prediction[0],
177
  "Escalation_Level": prediction[1],
 
187
  return {"message": "Input is valid."}
188
 
189
  @app.get("/test")
190
+ def test_api():
191
+ return {"message": "Test successful."}
192
+