Spaces:
Sleeping
Sleeping
added labels as in input
Browse files
main.py
CHANGED
|
@@ -5,6 +5,7 @@ import torch
|
|
| 5 |
from detoxify import Detoxify
|
| 6 |
import asyncio
|
| 7 |
from fastapi.concurrency import run_in_threadpool
|
|
|
|
| 8 |
|
| 9 |
class Guardrail:
|
| 10 |
def __init__(self):
|
|
@@ -60,17 +61,20 @@ class TopicBannerClassifier:
|
|
| 60 |
device=torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 61 |
)
|
| 62 |
self.hypothesis_template = "This text is about {}"
|
| 63 |
-
self.classes_verbalized = ["politics", "economy", "entertainment", "environment"]
|
| 64 |
|
| 65 |
-
async def classify(self, text):
|
| 66 |
return await run_in_threadpool(
|
| 67 |
self.classifier,
|
| 68 |
text,
|
| 69 |
-
|
| 70 |
hypothesis_template=self.hypothesis_template,
|
| 71 |
multi_label=False
|
| 72 |
)
|
| 73 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 74 |
class TopicBannerResult(BaseModel):
|
| 75 |
sequence: str
|
| 76 |
labels: list
|
|
@@ -108,9 +112,9 @@ async def classify_text(text_prompt: TextPrompt):
|
|
| 108 |
raise HTTPException(status_code=500, detail=str(e))
|
| 109 |
|
| 110 |
@app.post("/api/models/TopicBanner/classify", response_model=TopicBannerResult)
|
| 111 |
-
async def classify_topic_banner(
|
| 112 |
try:
|
| 113 |
-
result = await topic_banner_classifier.classify(
|
| 114 |
return {
|
| 115 |
"sequence": result["sequence"],
|
| 116 |
"labels": result["labels"],
|
|
|
|
| 5 |
from detoxify import Detoxify
|
| 6 |
import asyncio
|
| 7 |
from fastapi.concurrency import run_in_threadpool
|
| 8 |
+
from typing import List
|
| 9 |
|
| 10 |
class Guardrail:
|
| 11 |
def __init__(self):
|
|
|
|
| 61 |
device=torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 62 |
)
|
| 63 |
self.hypothesis_template = "This text is about {}"
|
|
|
|
| 64 |
|
| 65 |
+
async def classify(self, text, labels):
|
| 66 |
return await run_in_threadpool(
|
| 67 |
self.classifier,
|
| 68 |
text,
|
| 69 |
+
labels,
|
| 70 |
hypothesis_template=self.hypothesis_template,
|
| 71 |
multi_label=False
|
| 72 |
)
|
| 73 |
|
| 74 |
+
class TopicBannerRequest(BaseModel):
|
| 75 |
+
prompt: str
|
| 76 |
+
labels: List[str]
|
| 77 |
+
|
| 78 |
class TopicBannerResult(BaseModel):
|
| 79 |
sequence: str
|
| 80 |
labels: list
|
|
|
|
| 112 |
raise HTTPException(status_code=500, detail=str(e))
|
| 113 |
|
| 114 |
@app.post("/api/models/TopicBanner/classify", response_model=TopicBannerResult)
|
| 115 |
+
async def classify_topic_banner(request: TopicBannerRequest):
|
| 116 |
try:
|
| 117 |
+
result = await topic_banner_classifier.classify(request.prompt, request.labels)
|
| 118 |
return {
|
| 119 |
"sequence": result["sequence"],
|
| 120 |
"labels": result["labels"],
|