Ludovic Moncla commited on
Commit
5287b31
·
1 Parent(s): 4a51b31

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +21 -9
app.py CHANGED
@@ -1,15 +1,24 @@
1
  import gradio as gr
2
- from transformers import AutoTokenizer, AutoModelForTokenClassification, pipeline
3
 
4
- # Load your fine-tuned CamemBERT NER model
5
- model_name = "GEODE/camembert-base-edda-span-classification"
6
- tokenizer = AutoTokenizer.from_pretrained(model_name)
7
- model = AutoModelForTokenClassification.from_pretrained(model_name)
8
 
9
 
10
- ner_pipeline = pipeline("token-classification", model=model, tokenizer=tokenizer, aggregation_strategy="simple")
 
 
 
 
 
 
 
 
 
 
 
11
 
12
  def extract_coordinates(text):
 
 
13
  # Run NER
14
  entities = ner_pipeline(text)
15
 
@@ -21,11 +30,11 @@ def extract_coordinates(text):
21
  else:
22
  return "No coordinates found"
23
 
 
24
 
25
  def norm_coordinates(text):
26
-
27
- generator = pipeline("text2text-generation", model="GEODE/mt5-small-coords-norm")
28
-
29
  # Example input text
30
  input_text = "extract_coordinates: " + text
31
 
@@ -60,6 +69,9 @@ with gr.Blocks() as demo:
60
 
61
 
62
  with gr.Column():
 
 
 
63
  out_text = gr.Textbox(label="Extracted coordinates (fine-tuned CamemBERT NER)")
64
  run_btn.click(fn=extract_coordinates, inputs=inp, outputs=out_text)
65
 
 
1
  import gradio as gr
2
+ from transformers import pipeline
3
 
 
 
 
 
4
 
5
 
6
+ binary_classifier = pipeline("text-classification", model="GEODE/bert-base-multilingual-cased-binary-classifier-edda-coords")
7
+ ner_pipeline = pipeline("token-classification", model="GEODE/camembert-base-edda-span-classification", aggregation_strategy="simple")
8
+ generator = pipeline("text2text-generation", model="GEODE/mt5-small-coords-norm")
9
+
10
+ def detect_coordinates(text):
11
+ # Run binary classification
12
+ result = binary_classifier(text)
13
+
14
+ if result[0]['label'] == 'Positive':
15
+ return "Coordinates found"
16
+ else:
17
+ return "No coordinates found"
18
 
19
  def extract_coordinates(text):
20
+ if detect_coordinates(text) == "No coordinates found":
21
+ return "No coordinates found"
22
  # Run NER
23
  entities = ner_pipeline(text)
24
 
 
30
  else:
31
  return "No coordinates found"
32
 
33
+ # bert-base-multilingual-cased-binary-classifier-edda-coords
34
 
35
  def norm_coordinates(text):
36
+ if detect_coordinates(text) == "No coordinates found":
37
+ return "No coordinates found"
 
38
  # Example input text
39
  input_text = "extract_coordinates: " + text
40
 
 
69
 
70
 
71
  with gr.Column():
72
+ out_text = gr.Textbox(label="Extracted coordinates (fine-tuned CamemBERT NER)")
73
+ run_btn.click(fn=detect_coordinates, inputs=inp, outputs=out_text)
74
+
75
  out_text = gr.Textbox(label="Extracted coordinates (fine-tuned CamemBERT NER)")
76
  run_btn.click(fn=extract_coordinates, inputs=inp, outputs=out_text)
77