File size: 4,760 Bytes
9894d76 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 |
from dotenv import load_dotenv
from pathlib import Path
import os
import json
import pandas as pd
from huggingface_hub import InferenceClient
from rich import print
from rich.progress import Progress, SpinnerColumn, TextColumn, BarColumn
# Pull environment variables (notably HF_TOKEN) from a local .env file.
load_dotenv()

# Project data layout: inputs under data/, model outputs under data/results/,
# image files under data/imgs/.
DATA_PATH = Path(__file__).parent.parent / "data"
RESULTS_PATH = DATA_PATH / "results"
IMAGES_PATH = DATA_PATH / "imgs"

# Load the Label Studio annotation export (a JSON list of tasks).
annotations = json.loads((DATA_PATH / "annotations.json").read_text())
print(f"[bold]Loaded {len(annotations)} annotations from Label Studio[/bold]")
# Extract annotated images with their labels.
# Files should now be in data/imgs/ with their file_upload names.


def _first_choice(ann):
    """Return the first classification choice of an annotation, or None.

    Walks annotations[0].result[0].value.choices[0], tolerating any missing
    or empty level along the way.
    """
    ann_list = ann.get("annotations") or []
    if not ann_list:
        return None
    results = ann_list[0].get("result") or []
    if not results:
        return None
    choices = results[0].get("value", {}).get("choices") or []
    return choices[0] if choices else None


annotated_images = []
for ann in annotations:
    upload_name = ann.get("file_upload")
    if not upload_name:
        continue  # task without an uploaded file — nothing to classify
    annotated_images.append({
        "file_upload": upload_name,
        # File should be in data/imgs/ with the file_upload name.
        "file_path": IMAGES_PATH / upload_name,
        "choice": _first_choice(ann),
        "annotation_id": ann.get("id"),
    })

# Report how many of the referenced files actually exist on disk.
existing_files = [entry for entry in annotated_images
                  if entry["file_path"].exists()]
print(
    f"[bold]Found {len(existing_files)}/{len(annotated_images)} annotated image files[/bold]")
# Fail fast: without a token the inference client cannot work, so validate
# BEFORE constructing the client (previously the check ran after creation).
if not os.environ.get("HF_TOKEN"):
    raise ValueError(
        "HF_TOKEN environment variable not set. Please set it in .env file.")

# Initialize the serverless HF Inference API client.
client = InferenceClient(
    provider="hf-inference",
    api_key=os.environ.get("HF_TOKEN"),
)

# Filter to only images that exist on disk — only those can be classified.
images_to_process = [
    img for img in annotated_images if img["file_path"].exists()]
if not images_to_process:
    print("[red]✗ No annotated image files found![/red]")
    print("[yellow]Please copy files from Label Studio media directory first.[/yellow]")
    print(f"[dim]Expected location: {IMAGES_PATH}[/dim]")
    # Nothing to do — stop here instead of running the empty loop below.
    raise SystemExit(1)
predictions = []  # one row per successfully classified image
errors = []       # one row per failure (missing file or API error)

print(f"[bold]Processing {len(images_to_process)} annotated images...[/bold]")

with Progress(
    SpinnerColumn(),
    TextColumn("[progress.description]{task.description}"),
    BarColumn(),
    TextColumn("[progress.percentage]{task.percentage:>3.0f}%"),
    console=None,
) as progress:
    task = progress.add_task("Classifying images",
                             total=len(images_to_process))
    for ann_img in images_to_process:
        file_path = ann_img["file_path"]
        file_upload = ann_img["file_upload"]
        choice = ann_img.get("choice")
        annotation_id = ann_img.get("annotation_id")
        try:
            if not file_path or not file_path.exists():
                errors.append({
                    "annotation_id": annotation_id,
                    "file_upload": file_upload,
                    "error": "File not found"
                })
                # BUG FIX: no explicit progress.update here — `continue`
                # still runs the `finally` clause below, so the previous
                # extra update advanced the bar twice for missing files.
                continue
            # Classify the image with the NSFW-detection model.
            output = client.image_classification(
                str(file_path),
                model="Falconsai/nsfw_image_detection"
            )
            # Flatten the ranked predictions (list of dicts) into one row
            # and attach the annotation metadata for later comparison.
            result = {
                "annotation_id": annotation_id,
                "file_upload": file_upload,
                "actual_filename": file_path.name,
                "label_studio_choice": choice,
                **{f"label_{i}": pred["label"] for i, pred in enumerate(output)},
                **{f"score_{i}": pred["score"] for i, pred in enumerate(output)}
            }
            predictions.append(result)
        except Exception as e:
            # Best-effort batch job: record the failure and keep going.
            errors.append({
                "annotation_id": annotation_id,
                "file_upload": file_upload,
                "error": str(e)
            })
            print(f"[red]Error processing {file_upload}: {e}[/red]")
        finally:
            # Runs on success, error, and `continue` alike — exactly one tick
            # per image.
            progress.update(task, advance=1)
# Save predictions alongside the Label Studio ground-truth choice.
if predictions:
    # to_csv does not create missing directories — ensure data/results/ exists.
    RESULTS_PATH.mkdir(parents=True, exist_ok=True)
    predictions_df = pd.DataFrame(predictions)
    predictions_df.to_csv(RESULTS_PATH / "baseline.csv", index=False)
    print(
        f"[green]✓ Saved {len(predictions)} predictions to baseline.csv[/green]")
else:
    print("[red]✗ No predictions generated[/red]")

# Save per-image failures, if any, for later inspection.
# NOTE(review): errors are written to data/ while predictions go to
# data/results/ — presumably intentional, but confirm the asymmetry.
if errors:
    errors_df = pd.DataFrame(errors)
    errors_df.to_csv(DATA_PATH / "falcons_errors.csv", index=False)
    print(
        f"[yellow]⚠ Saved {len(errors)} errors to falcons_errors.csv[/yellow]")
|