Spaces:
Runtime error
Runtime error
paul hilders commited on
Commit ·
cca85c2
1
Parent(s): d80767e
Import spacy model
Browse files
app.py
CHANGED
|
@@ -8,6 +8,7 @@ import torch
|
|
| 8 |
import CLIP.clip as clip
|
| 9 |
|
| 10 |
import spacy
|
|
|
|
| 11 |
|
| 12 |
|
| 13 |
from clip_grounding.utils.image import pad_to_square
|
|
@@ -24,7 +25,7 @@ clip.clip._MODELS = {
|
|
| 24 |
device = "cuda" if torch.cuda.is_available() else "cpu"
|
| 25 |
model, preprocess = clip.load("ViT-B/32", device=device, jit=False)
|
| 26 |
|
| 27 |
-
|
| 28 |
|
| 29 |
# Gradio Section:
|
| 30 |
def run_demo(image, text):
|
|
@@ -43,6 +44,16 @@ def run_demo(image, text):
|
|
| 43 |
for i, token in enumerate(text_tokens_decoded):
|
| 44 |
highlighted_text.append((str(token), float(text_scores[i])))
|
| 45 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 46 |
return overlapped, highlighted_text
|
| 47 |
|
| 48 |
input_img = gr.inputs.Image(type='pil', label="Original Image")
|
|
|
|
| 8 |
import CLIP.clip as clip
|
| 9 |
|
| 10 |
import spacy
|
| 11 |
+
from spacy import displacy
|
| 12 |
|
| 13 |
|
| 14 |
from clip_grounding.utils.image import pad_to_square
|
|
|
|
| 25 |
device = "cuda" if torch.cuda.is_available() else "cpu"
|
| 26 |
model, preprocess = clip.load("ViT-B/32", device=device, jit=False)
|
| 27 |
|
| 28 |
+
nlp = spacy.load("en_core_web_sm")
|
| 29 |
|
| 30 |
# Gradio Section:
|
| 31 |
def run_demo(image, text):
|
|
|
|
| 44 |
for i, token in enumerate(text_tokens_decoded):
|
| 45 |
highlighted_text.append((str(token), float(text_scores[i])))
|
| 46 |
|
| 47 |
+
# Apply NER to extract named entities, and run the explainability method
|
| 48 |
+
# for each named entity.
|
| 49 |
+
highlighed_entities = []
|
| 50 |
+
for ent in nlp(text).ents:
|
| 51 |
+
ent_text = ent.text
|
| 52 |
+
ent_label = ent.label_
|
| 53 |
+
highlighed_entities.append((ent_text, ent_label))
|
| 54 |
+
|
| 55 |
+
print(highlighed_entities)
|
| 56 |
+
|
| 57 |
return overlapped, highlighted_text
|
| 58 |
|
| 59 |
input_img = gr.inputs.Image(type='pil', label="Original Image")
|