Spaces:
Sleeping
Sleeping
loplopez committed on
Commit ·
8ad2ef4
1
Parent(s): c8df78e
tests on classification results
Browse files
- app/app.py +2 -2
- app/modules/classify.py +5 -4
- app/modules/redistribute.py +0 -2
app/app.py
CHANGED
|
@@ -41,9 +41,10 @@ async def rerank_items(input_data: RankingRequest) -> RankingResponse:
|
|
| 41 |
items = input_data.items
|
| 42 |
# TODO consider sampling them?
|
| 43 |
|
| 44 |
-
print(items)
|
| 45 |
reranked_ids, first_topic, insertion_pos = redistribute(platform=platform, items=items)
|
| 46 |
#reranked_ids = [ for id_ in reranked_ids]
|
|
|
|
|
|
|
| 47 |
|
| 48 |
user_in_db = user_db.get_user(user_id=user)
|
| 49 |
|
|
@@ -97,6 +98,5 @@ async def rerank_items(input_data: RankingRequest) -> RankingResponse:
|
|
| 97 |
|
| 98 |
# no civic content to boost on
|
| 99 |
else:
|
| 100 |
-
print("there")
|
| 101 |
return RankingResponse(ranked_ids=reranked_ids, new_items=[])
|
| 102 |
|
|
|
|
| 41 |
items = input_data.items
|
| 42 |
# TODO consider sampling them?
|
| 43 |
|
|
|
|
| 44 |
reranked_ids, first_topic, insertion_pos = redistribute(platform=platform, items=items)
|
| 45 |
#reranked_ids = [ for id_ in reranked_ids]
|
| 46 |
+
print("Receiving boost on: ", first_topic)
|
| 47 |
+
print("Position: ", insertion_pos)
|
| 48 |
|
| 49 |
user_in_db = user_db.get_user(user_id=user)
|
| 50 |
|
|
|
|
| 98 |
|
| 99 |
# no civic content to boost on
|
| 100 |
else:
|
|
|
|
| 101 |
return RankingResponse(ranked_ids=reranked_ids, new_items=[])
|
| 102 |
|
app/modules/classify.py
CHANGED
|
@@ -10,7 +10,7 @@ except:
|
|
| 10 |
print("No GPU available, running on CPU")
|
| 11 |
device = None
|
| 12 |
|
| 13 |
-
#model = pipeline("zero-shot-classification", model="facebook/bart-large-mnli")
|
| 14 |
model = pipeline("zero-shot-classification", model="valhalla/distilbart-mnli-12-9", device=device)
|
| 15 |
|
| 16 |
label_map = {
|
|
@@ -49,6 +49,7 @@ def classify(texts: List[str], labels: List[str]):
|
|
| 49 |
# Iterate through each text to check for special cases
|
| 50 |
for index, text in enumerate(texts):
|
| 51 |
if text == "NON-VALID":
|
|
|
|
| 52 |
# If text is "X", directly assign the label and score
|
| 53 |
results.append({
|
| 54 |
"sequence": text,
|
|
@@ -57,16 +58,16 @@ def classify(texts: List[str], labels: List[str]):
|
|
| 57 |
})
|
| 58 |
else:
|
| 59 |
# Otherwise, prepare for model processing
|
|
|
|
| 60 |
model_texts.append(text)
|
| 61 |
model_indices.append(index)
|
| 62 |
|
| 63 |
if model_texts:
|
| 64 |
# Process texts through the model if there are any
|
| 65 |
-
predicted_labels = model(model_texts, labels, multi_label=False, batch_size=
|
| 66 |
|
| 67 |
# Insert model results into the correct positions
|
| 68 |
for pred, idx in zip(predicted_labels, model_indices):
|
| 69 |
results.insert(idx, pred)
|
| 70 |
-
|
| 71 |
-
print(results)
|
| 72 |
return results
|
|
|
|
| 10 |
print("No GPU available, running on CPU")
|
| 11 |
device = None
|
| 12 |
|
| 13 |
+
#model = pipeline("zero-shot-classification", model="facebook/bart-large-mnli", device=device)
|
| 14 |
model = pipeline("zero-shot-classification", model="valhalla/distilbart-mnli-12-9", device=device)
|
| 15 |
|
| 16 |
label_map = {
|
|
|
|
| 49 |
# Iterate through each text to check for special cases
|
| 50 |
for index, text in enumerate(texts):
|
| 51 |
if text == "NON-VALID":
|
| 52 |
+
print("NON-VALID TEXT!!", text)
|
| 53 |
# If text is "X", directly assign the label and score
|
| 54 |
results.append({
|
| 55 |
"sequence": text,
|
|
|
|
| 58 |
})
|
| 59 |
else:
|
| 60 |
# Otherwise, prepare for model processing
|
| 61 |
+
#print("- text =>", text)
|
| 62 |
model_texts.append(text)
|
| 63 |
model_indices.append(index)
|
| 64 |
|
| 65 |
if model_texts:
|
| 66 |
# Process texts through the model if there are any
|
| 67 |
+
predicted_labels = model(model_texts, labels, multi_label=False, batch_size=32)
|
| 68 |
|
| 69 |
# Insert model results into the correct positions
|
| 70 |
for pred, idx in zip(predicted_labels, model_indices):
|
| 71 |
results.insert(idx, pred)
|
| 72 |
+
print([(r['labels'][0], r['sequence']) for r in results])
|
|
|
|
| 73 |
return results
|
app/modules/redistribute.py
CHANGED
|
@@ -24,9 +24,7 @@ def redistribute(platform, items):
|
|
| 24 |
mapped_scores = map_scores(predicted_labels=predicted_labels, default_label="something else")
|
| 25 |
first_topic, insertion_pos = get_first_relevant_label(predicted_labels=predicted_labels, mapped_scores=mapped_scores, default_label="something else")
|
| 26 |
# TODO include parent linking
|
| 27 |
-
print("OK--", predicted_labels)
|
| 28 |
reranked_ids, _ = distribute_evenly(ids=[item.id for item in items], scores=mapped_scores)
|
| 29 |
-
print(reranked_ids)
|
| 30 |
return reranked_ids, first_topic, insertion_pos
|
| 31 |
|
| 32 |
|
|
|
|
| 24 |
mapped_scores = map_scores(predicted_labels=predicted_labels, default_label="something else")
|
| 25 |
first_topic, insertion_pos = get_first_relevant_label(predicted_labels=predicted_labels, mapped_scores=mapped_scores, default_label="something else")
|
| 26 |
# TODO include parent linking
|
|
|
|
| 27 |
reranked_ids, _ = distribute_evenly(ids=[item.id for item in items], scores=mapped_scores)
|
|
|
|
| 28 |
return reranked_ids, first_topic, insertion_pos
|
| 29 |
|
| 30 |
|