import base64
import io
import os
import os.path as osp
import re
import string
import sys
from typing import Iterator, List, Optional, Sequence
from uuid import uuid4

import numpy as np
import pandas as pd
import streamlit as st
from huggingface_hub import hf_hub_download
from PIL import Image

# Kept last (as originally) so its names take precedence; this module reads
# PRIVATE_DATASET_REPO and HF_TOKEN from it.
from config import *


# Columns every dataset TSV must provide; any other columns are dropped.
_REQUIRED_COLUMNS = ["index", "image", "answer"]


def normalize_label(s: str) -> str:
    """Return *s* lower-cased and trimmed, with whitespace runs collapsed to one space."""
    return " ".join(s.lower().split())


def _prepare_dataset(df: pd.DataFrame) -> pd.DataFrame:
    """Reduce a raw TSV frame to the required columns and add ``answer_norm``.

    Rows with a missing value in any required column are dropped.

    Raises:
        KeyError: if a required column is absent.
    """
    df = df[_REQUIRED_COLUMNS].dropna().copy()
    # astype(str) first: pandas parses purely numeric answers as int/float,
    # which normalize_label (a str method chain) cannot handle.
    df["answer_norm"] = df["answer"].astype(str).map(normalize_label)
    # enforce string ids to avoid type mismatches
    df["index"] = df["index"].astype(str)
    return df


@st.cache_data(show_spinner=False)
def load_private_tsv(filename: str) -> pd.DataFrame:
    """Download a TSV file from a private HF dataset repo and load it.

    Returns a frame with columns ``index`` (str), ``image``, ``answer`` and
    ``answer_norm`` (``answer`` passed through ``normalize_label``).
    """
    local_path = hf_hub_download(
        repo_id=PRIVATE_DATASET_REPO,
        repo_type="dataset",
        filename=filename,
        token=HF_TOKEN,
    )
    return _prepare_dataset(pd.read_csv(local_path, sep="\t"))


def load_dataset_from_tsv(upload) -> pd.DataFrame:
    """Load a user-supplied TSV (path or file-like) with the same schema as ``load_private_tsv``.

    Raises:
        ValueError: if any of the ``index``/``image``/``answer`` columns is missing.
    """
    df = pd.read_csv(upload, sep="\t")
    missing = set(_REQUIRED_COLUMNS) - set(df.columns)
    if missing:
        raise ValueError(
            f"TSV must contain {sorted(_REQUIRED_COLUMNS)}. Missing: {sorted(missing)}"
        )
    return _prepare_dataset(df)


class ParseError(Exception):
    """Raised when a model response cannot be interpreted."""


def parse_prompt1_indices(text: str) -> List[int]:
    """Return the sorted, de-duplicated digits 1-9 found anywhere in *text*.

    Each digit is read on its own: "12" yields [1, 2] and "0" is ignored.
    """
    return sorted({int(d) for d in re.findall(r"[1-9]", text)})


def parse_prompt_1(text: str, target: Optional[str]) -> bool:
    """Interpret a yes/no response.

    The response's first word (surrounding punctuation stripped, so "Yes."
    and "No, it is not" work) decides: yes/y -> True, no/n -> False.
    Words that merely start with those letters ("not sure", "none",
    "yesterday") are rejected rather than guessed.
    *target* is not used.

    Raises:
        ParseError: if the first word is not a recognised yes/no token.
    """
    first_word = normalize_label(text).partition(" ")[0].strip(string.punctuation)
    if first_word in {"yes", "y"}:
        return True
    if first_word in {"no", "n"}:
        return False
    raise ParseError("Unclear yes/no response")


def parse_prompt_2(text: str, target: str) -> bool:
    """Return True when *text* exactly equals *target* (case- and whitespace-sensitive)."""
    # NOTE(review): a normalize_label-based comparison used to be commented out
    # here, so the exact match looks deliberate — confirm callers normalize first.
    return text == target


def chunk(lst: Sequence, n: int) -> Iterator[Sequence]:
    """Yield consecutive slices of *lst* of length *n*; the last may be shorter.

    Raises:
        ValueError: if *n* is not positive (raised on first iteration).
    """
    if n <= 0:
        raise ValueError(f"chunk size must be positive, got {n}")
    for start in range(0, len(lst), n):
        yield lst[start:start + n]


def encode_base64_image(image: Image.Image, fmt: str = "PNG") -> str:
    """Serialize *image* in format *fmt* (default PNG) and return ASCII base64.

    The result is bare base64 with no ``data:`` URI prefix.
    """
    buf = io.BytesIO()
    image.save(buf, format=fmt)
    return base64.b64encode(buf.getvalue()).decode("ascii")


def decode_base64_image(b64: str) -> Image.Image:
    """Decode base64 image data (bare, or a ``data:...,`` URI) into an RGB image."""
    # Drop a data-URI header such as "data:image/png;base64," when present.
    if "," in b64 and b64.strip().lower().startswith("data:"):
        b64 = b64.split(",", 1)[1]
    data = base64.b64decode(b64)
    return Image.open(io.BytesIO(data)).convert("RGB")