kgemera committed on
Commit
467bc07
verified
1 Parent(s): afd52b3

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +24 -23
app.py CHANGED
@@ -16,29 +16,30 @@ model_path = hf_hub_download(
16
  # Cargar el modelo usando la ruta absoluta
17
  model = joblib.load(model_path)
18
 
19
- def predict_text(title, abstract):
20
- text = f"{title} {abstract}"
21
- pred = model.predict([text])[0]
22
- return pred
23
 
24
- def predict_file(file):
25
- df = pd.read_csv(file, sep=";")
26
- df["Prediction"] = model.predict(df["title"] + " " + df["abstract"])
27
- return df
28
 
 
29
  with gr.Blocks() as demo:
30
- gr.Markdown("# 馃┖ Medical Text Classifier")
31
-
32
- with gr.Tab("Texto individual"):
33
- title = gr.Textbox(label="T铆tulo")
34
- abstract = gr.Textbox(label="Abstract")
35
- output = gr.Textbox(label="Predicci贸n")
36
- btn = gr.Button("Clasificar")
37
- btn.click(predict_text, inputs=[title, abstract], outputs=output)
38
-
39
- with gr.Tab("Archivo CSV"):
40
- file_input = gr.File(label="Subir CSV", file_types=[".csv"])
41
- file_output = gr.Dataframe()
42
- file_input.change(predict_file, inputs=file_input, outputs=file_output)
43
-
44
- demo.launch()
 
 
 
 
16
  # Cargar el modelo usando la ruta absoluta
17
  model = joblib.load(model_path)
18
 
19
# --- Helper functions ---
def predict_single(text):
    """Return the model's predicted category for a single text string."""
    # The estimator expects an iterable of documents, so wrap the text
    # in a one-element list and take the first (only) prediction back.
    prediction = model.predict([text])
    return prediction[0]
 
22
 
23
def predict_batch(texts):
    """Return a list with the predicted category for each text in *texts*.

    Calls the model once with the whole batch instead of once per item:
    scikit-learn estimators vectorize over their input, so a single
    ``predict`` call avoids N separate pipeline invocations.
    """
    # Guard the empty batch explicitly — some estimators raise on
    # zero-length input, while the original loop returned [].
    if not texts:
        return []
    return list(model.predict(texts))
 
 
25
 
26
# --- Blocks UI ---
with gr.Blocks() as demo:
    # Endpoint: single prediction
    single_input = gr.Textbox(label="Input text")
    single_output = gr.Textbox(label="Predicted category")
    single_btn = gr.Button("Predict")
    single_btn.click(predict_single, inputs=single_input, outputs=single_output)

    # Endpoint: batch prediction
    batch_input = gr.Textbox(label="Batch texts (comma separated)")
    batch_output = gr.Textbox(label="Predictions (comma separated)")
    batch_btn = gr.Button("Predict Batch")
    batch_btn.click(
        # Join the predictions into a string so the Textbox shows
        # "a, b, c" (matching its "comma separated" label) instead of
        # the repr of a Python list.
        lambda x: ", ".join(str(p) for p in predict_batch(x.split(","))),
        inputs=batch_input,
        outputs=batch_output,
    )

# Launch the app. NOTE(review): `gr.Blocks.launch()` has no `api_mode`
# parameter — passing it raises TypeError at startup. The REST API is
# enabled by default and can be toggled with `show_api` if needed.
demo.launch(server_name="0.0.0.0", server_port=7860)