Spaces:
Runtime error
Runtime error
polinaeterna
commited on
Commit
·
2bd0078
1
Parent(s):
906b0be
update app
Browse files
app.py
CHANGED
|
@@ -24,7 +24,7 @@ class QualityModel(nn.Module, PyTorchModelHubMixin):
|
|
| 24 |
outputs = self.fc(dropped)
|
| 25 |
return torch.softmax(outputs[:, 0, :], dim=1)
|
| 26 |
|
| 27 |
-
device = "cuda"
|
| 28 |
config = AutoConfig.from_pretrained("nvidia/quality-classifier-deberta")
|
| 29 |
tokenizer = AutoTokenizer.from_pretrained("nvidia/quality-classifier-deberta")
|
| 30 |
model = QualityModel.from_pretrained("nvidia/quality-classifier-deberta").to(device)
|
|
@@ -44,7 +44,8 @@ def predict(texts: list[str]):
|
|
| 44 |
return predicted_domains
|
| 45 |
|
| 46 |
|
| 47 |
-
def run_quality_check(dataset,
|
|
|
|
| 48 |
data = pl.read_parquet(f"hf://datasets/{dataset}@parquet~/{config}/train/0000.parquet", columns=[column])
|
| 49 |
texts = data[column].tolist()
|
| 50 |
predictions = predict(texts[:n_samples])
|
|
@@ -65,12 +66,12 @@ with gr.Blocks() as demo:
|
|
| 65 |
search_type="dataset",
|
| 66 |
value="HuggingFaceFW/fineweb",
|
| 67 |
)
|
| 68 |
-
config_name = "default"
|
| 69 |
@gr.render(inputs=dataset_name)
|
| 70 |
def embed(name):
|
| 71 |
html_code = f"""
|
| 72 |
<iframe
|
| 73 |
-
src="https://huggingface.co/datasets/{name}/embed/viewer/
|
| 74 |
frameborder="0"
|
| 75 |
width="100%"
|
| 76 |
height="700px"
|
|
@@ -82,5 +83,5 @@ with gr.Blocks() as demo:
|
|
| 82 |
gr_check_btn = gr.Button("Check Dataset")
|
| 83 |
# plot = gr.BarPlot()
|
| 84 |
df = gr.DataFrame(visible=False)
|
| 85 |
-
gr_check_btn.click(run_quality_check, inputs=[dataset_name,
|
| 86 |
gr.BarPlot(df)
|
|
|
|
| 24 |
outputs = self.fc(dropped)
|
| 25 |
return torch.softmax(outputs[:, 0, :], dim=1)
|
| 26 |
|
| 27 |
+
device = "cuda" if torch.cuda.is_available() else "cpu"
|
| 28 |
config = AutoConfig.from_pretrained("nvidia/quality-classifier-deberta")
|
| 29 |
tokenizer = AutoTokenizer.from_pretrained("nvidia/quality-classifier-deberta")
|
| 30 |
model = QualityModel.from_pretrained("nvidia/quality-classifier-deberta").to(device)
|
|
|
|
| 44 |
return predicted_domains
|
| 45 |
|
| 46 |
|
| 47 |
+
def run_quality_check(dataset, column, n_samples):
|
| 48 |
+
config = "default"
|
| 49 |
data = pl.read_parquet(f"hf://datasets/{dataset}@parquet~/{config}/train/0000.parquet", columns=[column])
|
| 50 |
texts = data[column].tolist()
|
| 51 |
predictions = predict(texts[:n_samples])
|
|
|
|
| 66 |
search_type="dataset",
|
| 67 |
value="HuggingFaceFW/fineweb",
|
| 68 |
)
|
| 69 |
+
# config_name = "default" # TODO: user input
|
| 70 |
@gr.render(inputs=dataset_name)
|
| 71 |
def embed(name):
|
| 72 |
html_code = f"""
|
| 73 |
<iframe
|
| 74 |
+
src="https://huggingface.co/datasets/{name}/embed/viewer/default/train"
|
| 75 |
frameborder="0"
|
| 76 |
width="100%"
|
| 77 |
height="700px"
|
|
|
|
| 83 |
gr_check_btn = gr.Button("Check Dataset")
|
| 84 |
# plot = gr.BarPlot()
|
| 85 |
df = gr.DataFrame(visible=False)
|
| 86 |
+
gr_check_btn.click(run_quality_check, inputs=[dataset_name, text_column, n_samples], outputs=[df])
|
| 87 |
gr.BarPlot(df)
|