Spaces:
Running
on
Zero
Running
on
Zero
Lord-Raven
commited on
Commit
·
50b814c
1
Parent(s):
09439d2
Experimenting with few-shot classification.
Browse files- app.py +40 -2
- requirements.txt +1 -0
app.py
CHANGED
|
@@ -6,6 +6,7 @@ from transformers import pipeline
|
|
| 6 |
from optimum.onnxruntime import ORTModelForSequenceClassification
|
| 7 |
from fastapi import FastAPI
|
| 8 |
from fastapi.middleware.cors import CORSMiddleware
|
|
|
|
| 9 |
|
| 10 |
# CORS Config
|
| 11 |
app = FastAPI()
|
|
@@ -18,6 +19,26 @@ app.add_middleware(
|
|
| 18 |
allow_headers=["*"],
|
| 19 |
)
|
| 20 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 21 |
# "xenova/mobilebert-uncased-mnli" "typeform/mobilebert-uncased-mnli" Fast but small--same as bundled in Statosphere
|
| 22 |
# "xenova/deberta-v3-base-tasksource-nli" Not impressed
|
| 23 |
# "Xenova/bart-large-mnli" A bit slow
|
|
@@ -29,22 +50,39 @@ tokenizer_name = "cross-encoder/nli-deberta-v3-small"
|
|
| 29 |
model = ORTModelForSequenceClassification.from_pretrained(model_name, file_name=file_name)
|
| 30 |
tokenizer = AutoTokenizer.from_pretrained(tokenizer_name, model_max_length=512)
|
| 31 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 32 |
# file = cached_download("https://huggingface.co/" + model_name + "")
|
| 33 |
# sess = InferenceSession(file)
|
| 34 |
|
| 35 |
classifier = pipeline(task="zero-shot-classification", model=model, tokenizer=tokenizer)
|
| 36 |
|
| 37 |
-
def
|
| 38 |
if request:
|
| 39 |
if request.headers["origin"] not in ["https://statosphere-3704059fdd7e.c5v4v4jx6pq5.win", "https://crunchatize-77a78ffcc6a6.c5v4v4jx6pq5.win", "https://ravenok-statosphere-backend.hf.space"]:
|
| 40 |
return "{}"
|
| 41 |
data = json.loads(data_string)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 42 |
results = classifier(data['sequence'], candidate_labels=data['candidate_labels'], hypothesis_template=data['hypothesis_template'], multi_label=data['multi_label'])
|
| 43 |
response_string = json.dumps(results)
|
| 44 |
return response_string
|
| 45 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 46 |
gradio_interface = gradio.Interface(
|
| 47 |
-
fn =
|
| 48 |
inputs = gradio.Textbox(label="JSON Input"),
|
| 49 |
outputs = gradio.Textbox()
|
| 50 |
)
|
|
|
|
| 6 |
from optimum.onnxruntime import ORTModelForSequenceClassification
|
| 7 |
from fastapi import FastAPI
|
| 8 |
from fastapi.middleware.cors import CORSMiddleware
|
| 9 |
+
from setfit import SetFitModel
|
| 10 |
|
| 11 |
# CORS Config
|
| 12 |
app = FastAPI()
|
|
|
|
| 19 |
allow_headers=["*"],
|
| 20 |
)
|
| 21 |
|
| 22 |
+
class OnnxSetFitModel:
|
| 23 |
+
def __init__(self, ort_model, tokenizer, model_head):
|
| 24 |
+
self.ort_model = ort_model
|
| 25 |
+
self.tokenizer = tokenizer
|
| 26 |
+
self.model_head = model_head
|
| 27 |
+
|
| 28 |
+
def predict(self, inputs):
|
| 29 |
+
encoded_inputs = self.tokenizer(
|
| 30 |
+
inputs, padding=True, truncation=True, return_tensors="pt"
|
| 31 |
+
).to(self.ort_model.device)
|
| 32 |
+
|
| 33 |
+
outputs = self.ort_model(**encoded_inputs)
|
| 34 |
+
embeddings = mean_pooling(
|
| 35 |
+
outputs["last_hidden_state"], encoded_inputs["attention_mask"]
|
| 36 |
+
)
|
| 37 |
+
return self.model_head.predict(embeddings.cpu())
|
| 38 |
+
|
| 39 |
+
def __call__(self, inputs):
|
| 40 |
+
return self.predict(inputs)
|
| 41 |
+
|
| 42 |
# "xenova/mobilebert-uncased-mnli" "typeform/mobilebert-uncased-mnli" Fast but small--same as bundled in Statosphere
|
| 43 |
# "xenova/deberta-v3-base-tasksource-nli" Not impressed
|
| 44 |
# "Xenova/bart-large-mnli" A bit slow
|
|
|
|
| 50 |
model = ORTModelForSequenceClassification.from_pretrained(model_name, file_name=file_name)
|
| 51 |
tokenizer = AutoTokenizer.from_pretrained(tokenizer_name, model_max_length=512)
|
| 52 |
|
| 53 |
+
few_shot_model_name = "moshew/bge-small-en-v1.5_setfit-sst2-english"
|
| 54 |
+
few_shot_model = setFitModel.from_pretrained(few_shot_model_name)
|
| 55 |
+
few_shot_tokenizer = AutoTokenizer.from_pretrained('bge_auto_opt_04', model_max_length=512)
|
| 56 |
+
ort_model = ORTModelForFeatureExtraction.from_pretrained('bge_auto_opt_O4')
|
| 57 |
+
onnx_few_shot_model = OnnxSetFitModel(ort_model, tokenizer, model.model_head)
|
| 58 |
+
|
| 59 |
# file = cached_download("https://huggingface.co/" + model_name + "")
|
| 60 |
# sess = InferenceSession(file)
|
| 61 |
|
| 62 |
classifier = pipeline(task="zero-shot-classification", model=model, tokenizer=tokenizer)
|
| 63 |
|
| 64 |
+
def classify(data_string, request: gradio.Request):
|
| 65 |
if request:
|
| 66 |
if request.headers["origin"] not in ["https://statosphere-3704059fdd7e.c5v4v4jx6pq5.win", "https://crunchatize-77a78ffcc6a6.c5v4v4jx6pq5.win", "https://ravenok-statosphere-backend.hf.space"]:
|
| 67 |
return "{}"
|
| 68 |
data = json.loads(data_string)
|
| 69 |
+
if (data['task'] == 'few_shot_classification')
|
| 70 |
+
return few_shot_classification
|
| 71 |
+
else
|
| 72 |
+
return zero_shot_classification
|
| 73 |
+
|
| 74 |
+
def zero_shot_classification(data):
|
| 75 |
results = classifier(data['sequence'], candidate_labels=data['candidate_labels'], hypothesis_template=data['hypothesis_template'], multi_label=data['multi_label'])
|
| 76 |
response_string = json.dumps(results)
|
| 77 |
return response_string
|
| 78 |
|
| 79 |
+
def few_shot_classification(data):
|
| 80 |
+
results = onnx_few_shot_model(data['sequence'])
|
| 81 |
+
response_string = json.dumps(results)
|
| 82 |
+
return response_string
|
| 83 |
+
|
| 84 |
gradio_interface = gradio.Interface(
|
| 85 |
+
fn = classify,
|
| 86 |
inputs = gradio.Textbox(label="JSON Input"),
|
| 87 |
outputs = gradio.Textbox()
|
| 88 |
)
|
requirements.txt
CHANGED
|
@@ -2,5 +2,6 @@ fastapi==0.88.0
|
|
| 2 |
json5==0.9.10
|
| 3 |
numpy==1.23.4
|
| 4 |
optimum[exporters,onnxruntime]==1.21.3
|
|
|
|
| 5 |
torch==1.12.1
|
| 6 |
torchvision==0.13.1
|
|
|
|
| 2 |
json5==0.9.10
|
| 3 |
numpy==1.23.4
|
| 4 |
optimum[exporters,onnxruntime]==1.21.3
|
| 5 |
+
setfit==1.0.3
|
| 6 |
torch==1.12.1
|
| 7 |
torchvision==0.13.1
|