RecaptchaLLM / src /utils.py
chris1nexus
Update prompt
1e6b325
import base64
import io
import os
import os
import os.path as osp
import re
import string
import sys
from typing import List, Optional
from uuid import uuid4

import numpy as np
import pandas as pd
import pandas as pd
import streamlit as st
from huggingface_hub import hf_hub_download
from PIL import Image

from config import *
def normalize_label(s: str) -> str:
    """Canonicalize a label for comparison.

    The text is lowercased, leading/trailing whitespace is removed, and every
    internal run of whitespace (spaces, tabs, newlines) becomes a single space.
    """
    # str.split() with no separator already discards leading/trailing
    # whitespace and collapses runs, so no separate strip() is needed.
    tokens = s.lower().split()
    return " ".join(tokens)
@st.cache_data(show_spinner=False)
def load_private_tsv(filename: str) -> pd.DataFrame:
    """Download a TSV file from a private HF dataset repo and load it.

    The file is fetched with ``hf_hub_download`` from ``PRIVATE_DATASET_REPO``
    using ``HF_TOKEN``. Both are presumably defined in ``config``, which is
    star-imported; confirm there. Results are memoized by ``st.cache_data``.

    Args:
        filename: Path of the TSV file inside the dataset repo.

    Returns:
        A DataFrame with the columns ``index`` (coerced to str), ``image``,
        ``answer`` and ``answer_norm``. Rows with a missing value in any of
        the first three columns are dropped.

    Raises:
        KeyError: If the TSV lacks any of ``index``, ``image`` or ``answer``.
    """
    local_path = hf_hub_download(
        repo_id=PRIVATE_DATASET_REPO,
        repo_type="dataset",
        filename=filename,
        token=HF_TOKEN,
    )
    df = pd.read_csv(local_path, sep="\t")
    df = df[["index", "image", "answer"]].dropna()
    # Use the shared normalize_label helper so answer_norm matches the
    # normalization used for uploaded TSVs: trim, lowercase, and collapse
    # internal whitespace. Calling astype(str) first guards against pandas
    # parsing an all-numeric answer column as int/float, which would break
    # string normalization. Float values stringify as e.g. "3.0".
    df["answer_norm"] = df["answer"].astype(str).apply(normalize_label)
    # enforce string ids to avoid type mismatches
    df["index"] = df["index"].astype(str)
    return df
def load_dataset_from_tsv(upload) -> pd.DataFrame:
    """Load a dataset from a tab-separated file.

    Args:
        upload: Anything ``pd.read_csv`` accepts, such as a path or a file-like
            object. Presumably this is a Streamlit upload; confirm against the
            caller.

    Returns:
        A DataFrame with the columns ``index`` (coerced to str), ``image``,
        ``answer`` and ``answer_norm``. ``answer_norm`` is ``answer`` passed
        through ``normalize_label``. Rows with a missing value in any of the
        first three columns are dropped.

    Raises:
        ValueError: If any of ``index``, ``image`` or ``answer`` is missing.
    """
    df = pd.read_csv(upload, sep="\t")
    required = {"index", "image", "answer"}
    missing = required - set(df.columns)
    if missing:
        raise ValueError(f"TSV must contain {sorted(required)}. Missing: {sorted(missing)}")
    df = df[["index", "image", "answer"]].dropna()
    # pandas parses an all-numeric answer column as int/float, and
    # normalize_label would then fail with AttributeError (no .strip on
    # numbers). Coerce to str first; str values are unchanged.
    df["answer_norm"] = df["answer"].astype(str).apply(normalize_label)
    # enforce string ids to avoid type mismatches
    df["index"] = df["index"].astype(str)
    return df
class ParseError(Exception):
    """Raised when a response text cannot be interpreted.

    For example, ``parse_prompt_1`` raises it when the text is neither a
    recognizable yes nor a recognizable no.
    """
def parse_prompt1_indices(text: str) -> List[int]:
    """Extract the distinct indices 1-9 mentioned in *text*.

    Every individual digit from 1 to 9 counts as one index. So "12" yields
    [1, 2], and "0" is ignored. Duplicates are removed, and the result is
    sorted in ascending order. Text without such digits gives an empty list.
    """
    found = {int(digit) for digit in re.findall(r"[1-9]", text)}
    return sorted(found)
def parse_prompt_1(text: str, target: Optional[str]) -> bool:
    """Interpret a free-text yes/no answer.

    The text is first passed through ``normalize_label``, which lowercases it,
    trims it, and collapses whitespace.

    - It reads as *yes* if it is exactly ``"y"`` or starts with ``"yes"``.
    - It reads as *no* if it is exactly ``"n"`` or starts with ``"no"``.

    Note that prefix matching means words such as "none" or "not ..." also
    read as no.

    Args:
        text: The response to interpret.
        target: Unused here. Presumably kept so this matches the
            ``(text, target)`` signature of ``parse_prompt_2``; confirm.

    Returns:
        True for a yes answer, False for a no answer.

    Raises:
        ParseError: If the text matches neither pattern.
    """
    answer = normalize_label(text)
    if answer == "y" or answer.startswith("yes"):
        return True
    if answer == "n" or answer.startswith("no"):
        return False
    raise ParseError("Unclear yes/no response")
def parse_prompt_2(text: str, target: str, *, normalize: bool = False) -> bool:
    """Return whether *text* matches *target*.

    By default the comparison is exact (``text == target``). It is therefore
    sensitive to case, surrounding whitespace, and internal whitespace, as
    this function has always been.

    Args:
        text: The response to check.
        target: The expected answer.
        normalize: If True, compare both sides after ``normalize_label``
            (lowercased, trimmed, whitespace collapsed) instead of exactly.
            This replaces the previously commented-out alternative.
    """
    if normalize:
        return normalize_label(text) == normalize_label(target)
    return text == target
def chunk(lst, n):
    """Yield successive slices of *lst* of length *n*, in order.

    The last slice is shorter when ``len(lst)`` is not a multiple of *n*. An
    empty *lst* yields nothing. Each slice has the type that slicing *lst*
    produces: a list for a list, a str for a str, and so on.

    Args:
        lst: Any sliceable sequence with ``len()``.
        n: The slice size; it must be at least 1.

    Raises:
        ValueError: If ``n < 1``. Because this is a generator, the error is
            raised on the first iteration, not at call time.
    """
    # Without this check, n == 0 fails deep inside range() with an unclear
    # message, and a negative n silently yields nothing.
    if n < 1:
        raise ValueError(f"chunk size must be >= 1, got {n!r}")
    for start in range(0, len(lst), n):
        yield lst[start:start + n]
def encode_base64_image(image: Image.Image, image_format: str = "PNG") -> str:
    """Serialize *image* with PIL and return the bytes as a base64 ASCII string.

    The result is bare base64 with no ``data:`` URL prefix.

    Args:
        image: The PIL image to encode.
        image_format: Any format name accepted by ``Image.save``, such as
            ``"PNG"``, ``"JPEG"`` or ``"WEBP"``. It defaults to ``"PNG"``, the
            previously hard-coded choice. Some formats (e.g. JPEG) may not
            support every image mode; convert the image first if ``save``
            raises.
    """
    buf = io.BytesIO()
    image.save(buf, format=image_format)
    return base64.b64encode(buf.getvalue()).decode("ascii")
def decode_base64_image(b64: str) -> Image.Image:
    """Decode a base64 string into an RGB PIL image.

    Both bare base64 and ``data:`` URLs (e.g. ``data:image/png;base64,...``)
    are accepted. When the string contains a comma and starts with ``data:``
    (ignoring case and leading whitespace), everything up to and including
    the first comma is discarded before decoding.

    Errors from ``base64.b64decode`` or from PIL are not caught here; they
    propagate to the caller.

    Returns:
        The decoded image converted to mode ``"RGB"``.
    """
    payload = b64
    if "," in b64 and b64.strip().lower().startswith("data:"):
        _header, payload = b64.split(",", 1)
    raw_bytes = base64.b64decode(payload)
    opened = Image.open(io.BytesIO(raw_bytes))
    return opened.convert("RGB")