|
|
from dotenv import load_dotenv |
|
|
from pathlib import Path |
|
|
import os |
|
|
import json |
|
|
import pandas as pd |
|
|
from huggingface_hub import InferenceClient |
|
|
from rich import print |
|
|
from rich.progress import Progress, SpinnerColumn, TextColumn, BarColumn |
|
|
|
|
|
load_dotenv()

# All data lives under <repo-root>/data, resolved relative to this file so
# the script works regardless of the current working directory.
DATA_PATH = Path(__file__).parent.parent / "data"

# Destination for the baseline prediction CSV written at the end of the run.
RESULTS_PATH = DATA_PATH / "results"

# Expected location of the image files copied from the Label Studio media
# directory (see the warning printed later when none are found).
IMAGES_PATH = DATA_PATH / "imgs"
|
|
|
|
|
# Read the Label Studio export (a JSON list of annotation tasks).
annotations = json.loads((DATA_PATH / "annotations.json").read_text())

print(f"[bold]Loaded {len(annotations)} annotations from Label Studio[/bold]")
|
|
|
|
|
|
|
|
|
|
|
def _first_choice(task):
    """Return the first choice value from a task's first annotation, or None.

    Label Studio nests the selected label as
    annotations[0].result[0].value.choices[0]; any missing level yields None.
    """
    task_annotations = task.get("annotations") or []
    if not task_annotations:
        return None
    results = task_annotations[0].get("result") or []
    if not results:
        return None
    choices = results[0].get("value", {}).get("choices") or []
    return choices[0] if choices else None


# One record per task that actually references an uploaded file.
annotated_images = []
for task in annotations:
    file_upload = task.get("file_upload")
    if not file_upload:
        continue  # task has no attached file; nothing to classify

    annotated_images.append({
        "file_upload": file_upload,
        "file_path": IMAGES_PATH / file_upload,
        "choice": _first_choice(task),
        "annotation_id": task.get("id"),
    })
|
|
|
|
|
|
|
|
# Report how many of the referenced image files are actually present on disk.
existing_files = [entry for entry in annotated_images
                  if entry["file_path"].exists()]
print(
    f"[bold]Found {len(existing_files)}/{len(annotated_images)} annotated image files[/bold]")
|
|
|
|
|
|
|
|
# Fail fast when the token is missing: previously the client was constructed
# first with api_key=None, which would only surface as an opaque 401 on the
# first inference request instead of this clear configuration error.
if not os.environ.get("HF_TOKEN"):
    raise ValueError(
        "HF_TOKEN environment variable not set. Please set it in .env file.")

# Hugging Face inference client used for all classification calls below.
client = InferenceClient(
    provider="hf-inference",
    api_key=os.environ.get("HF_TOKEN"),
)
|
|
|
|
|
|
|
|
# `existing_files` (computed above) is exactly the annotated images whose
# files exist on disk; reuse it rather than re-stat'ing every file with a
# second identical filter pass. Copy so later code can't mutate the original.
images_to_process = list(existing_files)

if not images_to_process:
    print("[red]✗ No annotated image files found![/red]")
    print("[yellow]Please copy files from Label Studio media directory first.[/yellow]")
    print(f"[dim]Expected location: {IMAGES_PATH}[/dim]")
|
|
|
|
|
# Accumulators for successful classifications and per-item failures.
predictions: list[dict] = []
errors: list[dict] = []

print(f"[bold]Processing {len(images_to_process)} annotated images...[/bold]")
|
|
|
|
|
# Classify every annotated image, tracking progress with a rich spinner/bar.
with Progress(
    SpinnerColumn(),
    TextColumn("[progress.description]{task.description}"),
    BarColumn(),
    TextColumn("[progress.percentage]{task.percentage:>3.0f}%"),
    console=None,
) as progress:
    task = progress.add_task("Classifying images",
                             total=len(images_to_process))

    for ann_img in images_to_process:
        file_path = ann_img["file_path"]
        file_upload = ann_img["file_upload"]
        choice = ann_img.get("choice")
        annotation_id = ann_img.get("annotation_id")

        try:
            # Defensive re-check: a file could disappear between the earlier
            # existence scan and this request.
            if not file_path or not file_path.exists():
                errors.append({
                    "annotation_id": annotation_id,
                    "file_upload": file_upload,
                    "error": "File not found"
                })
                # BUGFIX: do not advance the bar here. `continue` still runs
                # the `finally` block below, so the previous explicit
                # `progress.update(task, advance=1)` on this path advanced
                # the bar twice per skipped file.
                continue

            # Remote inference call; returns a ranked list of
            # {"label": ..., "score": ...} predictions.
            output = client.image_classification(
                str(file_path),
                model="Falconsai/nsfw_image_detection"
            )

            # Flatten the ranked predictions into label_i / score_i keys so
            # each record becomes one flat CSV row.
            result = {
                "annotation_id": annotation_id,
                "file_upload": file_upload,
                "actual_filename": file_path.name,
                "label_studio_choice": choice,
                **{f"label_{i}": pred["label"] for i, pred in enumerate(output)},
                **{f"score_{i}": pred["score"] for i, pred in enumerate(output)}
            }
            predictions.append(result)

        except Exception as e:
            # Best-effort batch job: record the failure and keep going.
            errors.append({
                "annotation_id": annotation_id,
                "file_upload": file_upload,
                "error": str(e)
            })
            print(f"[red]Error processing {file_upload}: {e}[/red]")

        finally:
            # Exactly one advance per item: success, skip, or failure.
            progress.update(task, advance=1)
|
|
|
|
|
|
|
|
if predictions:
    # Robustness: a fresh checkout may not contain data/results/; creating it
    # here prevents to_csv from failing with a missing-directory error.
    RESULTS_PATH.mkdir(parents=True, exist_ok=True)
    predictions_df = pd.DataFrame(predictions)
    predictions_df.to_csv(RESULTS_PATH / "baseline.csv", index=False)
    print(
        f"[green]✓ Saved {len(predictions)} predictions to baseline.csv[/green]")
else:
    print("[red]✗ No predictions generated[/red]")
|
|
|
|
|
|
|
|
if errors:
    # Persist per-item failures for later inspection.
    # NOTE(review): errors are written under DATA_PATH while predictions go
    # to RESULTS_PATH — confirm this asymmetry is intentional.
    pd.DataFrame(errors).to_csv(DATA_PATH / "falcons_errors.csv", index=False)
    print(
        f"[yellow]⚠ Saved {len(errors)} errors to falcons_errors.csv[/yellow]")
|
|
|