Ludovic Moncla commited on
Commit
4309fdf
·
1 Parent(s): 974cec8

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +5 -3
app.py CHANGED
@@ -2,11 +2,13 @@ import gradio as gr
2
  from transformers import pipeline
3
  import geopy
4
  import plotly.graph_objects as go
 
5
 
 
6
 
7
- binary_classifier = pipeline("text-classification", model="GEODE/bert-base-multilingual-cased-binary-classifier-edda-coords")
8
- ner_pipeline = pipeline("token-classification", model="GEODE/camembert-base-edda-span-classification", aggregation_strategy="simple")
9
- generator = pipeline("text2text-generation", model="GEODE/mt5-small-coords-norm")
10
 
11
 
12
  def create_map(lat, long):
 
2
  from transformers import pipeline
3
  import geopy
4
  import plotly.graph_objects as go
5
+ import torch
6
 
7
+ device = torch.device("mps" if torch.backends.mps.is_available() else ("cuda" if torch.cuda.is_available() else "cpu"))
8
 
9
+ binary_classifier = pipeline("text-classification", model="GEODE/bert-base-multilingual-cased-binary-classifier-edda-coords", truncation=True, device=device)
10
+ ner_pipeline = pipeline("token-classification", model="GEODE/camembert-base-edda-span-classification", aggregation_strategy="simple", device=device)
11
+ generator = pipeline("text2text-generation", model="GEODE/mt5-small-coords-norm", truncation=True, device=device)
12
 
13
 
14
  def create_map(lat, long):