Spaces:
Running
on
Zero
Running
on
Zero
Lord-Raven
commited on
Commit
·
fbcdba4
1
Parent(s):
bd9a53f
Experimenting with few-shot classification.
Browse files
app.py
CHANGED
|
@@ -7,8 +7,10 @@ from optimum.onnxruntime import ORTModelForSequenceClassification
|
|
| 7 |
from optimum.onnxruntime import ORTModelForFeatureExtraction
|
| 8 |
from fastapi import FastAPI
|
| 9 |
from fastapi.middleware.cors import CORSMiddleware
|
| 10 |
-
from setfit import SetFitModel
|
| 11 |
from setfit.exporters.utils import mean_pooling
|
|
|
|
|
|
|
| 12 |
|
| 13 |
# CORS Config
|
| 14 |
app = FastAPI()
|
|
@@ -55,9 +57,36 @@ classifier = pipeline(task="zero-shot-classification", model=model, tokenizer=to
|
|
| 55 |
|
| 56 |
few_shot_tokenizer = AutoTokenizer.from_pretrained('BAAI/bge-small-en-v1.5', model_max_length=512)
|
| 57 |
ort_model = ORTModelForFeatureExtraction.from_pretrained('BAAI/bge-small-en-v1.5', file_name="onnx/model.onnx")
|
| 58 |
-
few_shot_model = SetFitModel.from_pretrained("moshew/bge-small-en-v1.5_setfit-sst2-english")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 59 |
onnx_few_shot_model = OnnxSetFitModel(ort_model, few_shot_tokenizer, few_shot_model.model_head)
|
| 60 |
|
|
|
|
|
|
|
| 61 |
def classify(data_string, request: gradio.Request):
|
| 62 |
if request:
|
| 63 |
if request.headers["origin"] not in ["https://statosphere-3704059fdd7e.c5v4v4jx6pq5.win", "https://crunchatize-77a78ffcc6a6.c5v4v4jx6pq5.win", "https://ravenok-statosphere-backend.hf.space"]:
|
|
@@ -75,6 +104,7 @@ def zero_shot_classification(data):
|
|
| 75 |
|
| 76 |
def few_shot_classification(data):
|
| 77 |
results = onnx_few_shot_model(data['sequence'])
|
|
|
|
| 78 |
response_string = json.dumps(results.tolist())
|
| 79 |
return response_string
|
| 80 |
|
|
|
|
| 7 |
from optimum.onnxruntime import ORTModelForFeatureExtraction
|
| 8 |
from fastapi import FastAPI
|
| 9 |
from fastapi.middleware.cors import CORSMiddleware
|
| 10 |
+
from setfit import SetFitModel, Trainer, TrainingArguments
|
| 11 |
from setfit.exporters.utils import mean_pooling
|
| 12 |
+
from setfit import get_templated_dataset
|
| 13 |
+
from datasets import load_dataset
|
| 14 |
|
| 15 |
# CORS Config
|
| 16 |
app = FastAPI()
|
|
|
|
| 57 |
|
| 58 |
few_shot_tokenizer = AutoTokenizer.from_pretrained('BAAI/bge-small-en-v1.5', model_max_length=512)
|
| 59 |
ort_model = ORTModelForFeatureExtraction.from_pretrained('BAAI/bge-small-en-v1.5', file_name="onnx/model.onnx")
|
| 60 |
+
few_shot_model = SetFitModel.from_pretrained("moshew/bge-small-en-v1.5_setfit-sst2-english", multi_target_strategy="multi-output")
|
| 61 |
+
|
| 62 |
+
test_dataset = load_dataset("dair-ai/emotion", "split", split="test")
|
| 63 |
+
print(test_dataset)
|
| 64 |
+
classes = test_dataset.features["label"].names
|
| 65 |
+
print(classes)
|
| 66 |
+
train_dataset = get_templated_dataset()
|
| 67 |
+
print(train_dataset)
|
| 68 |
+
print(train_dataset[0])
|
| 69 |
+
|
| 70 |
+
args = TrainingArguments(
|
| 71 |
+
batch_size=32,
|
| 72 |
+
num_epochs=1
|
| 73 |
+
)
|
| 74 |
+
|
| 75 |
+
trainer = Trainer(
|
| 76 |
+
model=few_shot_model,
|
| 77 |
+
args=args,
|
| 78 |
+
train_dataset=train_dataset,
|
| 79 |
+
eval_dataset=test_dataset
|
| 80 |
+
)
|
| 81 |
+
trainer.train()
|
| 82 |
+
|
| 83 |
+
metrics = trainer.evaluate()
|
| 84 |
+
print(metrics)
|
| 85 |
+
|
| 86 |
onnx_few_shot_model = OnnxSetFitModel(ort_model, few_shot_tokenizer, few_shot_model.model_head)
|
| 87 |
|
| 88 |
+
|
| 89 |
+
|
| 90 |
def classify(data_string, request: gradio.Request):
|
| 91 |
if request:
|
| 92 |
if request.headers["origin"] not in ["https://statosphere-3704059fdd7e.c5v4v4jx6pq5.win", "https://crunchatize-77a78ffcc6a6.c5v4v4jx6pq5.win", "https://ravenok-statosphere-backend.hf.space"]:
|
|
|
|
| 104 |
|
| 105 |
def few_shot_classification(data):
|
| 106 |
results = onnx_few_shot_model(data['sequence'])
|
| 107 |
+
print([classes[idx] for idx in results)
|
| 108 |
response_string = json.dumps(results.tolist())
|
| 109 |
return response_string
|
| 110 |
|