File size: 2,647 Bytes
d0b2e68
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1e6b325
d0b2e68
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
import base64
import io
import os
import os.path as osp
import re
import string
import sys
from typing import List, Optional
from uuid import uuid4

import numpy as np
import pandas as pd
import streamlit as st
from huggingface_hub import hf_hub_download
from PIL import Image

from config import *



def normalize_label(s: str) -> str:
    """Return *s* lower-cased with its whitespace trimmed and collapsed.

    Leading and trailing whitespace is removed and every internal run of
    whitespace is replaced by a single space, so two labels that differ only
    in casing or spacing normalize to the same string.
    """
    lowered = s.strip().lower()
    tokens = lowered.split()
    return " ".join(tokens)

@st.cache_data(show_spinner=False)
def load_private_tsv(filename: str) -> pd.DataFrame:
    """Download a TSV file from a private HF dataset repo and load it.

    The file is fetched with ``hf_hub_download`` using the module-level
    ``PRIVATE_DATASET_REPO`` and ``HF_TOKEN`` names, then reduced to the
    ``index``, ``image`` and ``answer`` columns. Rows with a missing value in
    any of those columns are dropped. The result is cached by
    ``st.cache_data``.

    Args:
        filename: Path of the TSV file inside the dataset repo.

    Returns:
        A DataFrame with columns ``index`` (as ``str``), ``image``,
        ``answer`` and ``answer_norm`` (the answer passed through
        ``normalize_label``).

    Raises:
        KeyError: If the TSV lacks one of the three required columns.
    """
    local_path = hf_hub_download(
        repo_id=PRIVATE_DATASET_REPO,
        repo_type="dataset",
        filename=filename,
        token=HF_TOKEN,
    )
    df = pd.read_csv(local_path, sep="\t")
    df = df[["index", "image", "answer"]].dropna()
    # Normalize with the same helper as load_dataset_from_tsv so both loaders
    # produce comparable labels (internal whitespace runs are collapsed too).
    # Casting to str first keeps a purely numeric answer column from raising.
    df["answer_norm"] = df["answer"].astype(str).map(normalize_label)
    # enforce string ids to avoid type mismatches
    df["index"] = df["index"].astype(str)
    return df


def load_dataset_from_tsv(upload) -> pd.DataFrame:
    """Load a tab-separated dataset and prepare it for answer matching.

    Args:
        upload: Anything ``pd.read_csv`` accepts (a path or a file-like
            object) containing tab-separated data.

    Returns:
        A DataFrame with columns ``index`` (as ``str``), ``image``,
        ``answer`` and ``answer_norm`` (the answer passed through
        ``normalize_label``). Rows with a missing value in any of the three
        required columns are dropped.

    Raises:
        ValueError: If ``index``, ``image`` or ``answer`` is not a column of
            the TSV.
    """
    df = pd.read_csv(upload, sep="\t")
    required = {"index", "image", "answer"}
    missing = required - set(df.columns)
    if missing:
        raise ValueError(f"TSV must contain {sorted(required)}. Missing: {sorted(missing)}")

    df = df[["index", "image", "answer"]].dropna()
    # pandas parses an all-numeric answer column as int/float; cast to str so
    # normalize_label (which calls str methods) does not raise AttributeError.
    df["answer_norm"] = df["answer"].astype(str).map(normalize_label)
    # enforce string ids to avoid type mismatches
    df["index"] = df["index"].astype(str)
    return df

class ParseError(Exception):
    """Raised when a free-text reply cannot be interpreted by a parser."""





def parse_prompt1_indices(text: str) -> List[int]:
    """Collect the distinct non-zero digits found in *text*, in ascending order.

    Every digit character ``1``-``9`` counts on its own, so ``"12"``
    contributes both 1 and 2, and ``"10"`` contributes only 1 (zeros are
    ignored). Text without such digits yields an empty list.
    """
    distinct_digits = {int(ch) for ch in re.findall(r"[1-9]", text)}
    return sorted(distinct_digits)


def parse_prompt_1(text: str, target: Optional[str]) -> bool:
    """Interpret a free-text yes/no reply as a boolean.

    The reply is normalized with ``normalize_label`` and its leading word
    decides the result. The word must be a complete ``yes``/``y``/``no``/``n``
    token: it has to be followed by the end of the string, whitespace, or
    closing punctuation. Words that merely start with those letters
    ("not sure", "none", "yesterday") are therefore rejected instead of being
    read as an answer, while decorated replies such as "Yes.", "y." or
    "**No**" are accepted.

    Args:
        text: Raw reply text.
        target: Not used by this parser.
            NOTE(review): presumably kept so both prompt parsers share a
            call signature - confirm against the caller.

    Returns:
        ``True`` for a yes answer, ``False`` for a no answer.

    Raises:
        ParseError: If the reply does not begin with a recognizable yes/no
            word.
    """
    t = normalize_label(text)
    # \W* skips leading markup or punctuation; the lookahead enforces a token
    # boundary so that prefixes of longer words do not count as an answer.
    match = re.match(r"\W*(yes|no|y|n)(?=$|[\s.,;:!?*\"')\]])", t)
    if match is None:
        raise ParseError("Unclear yes/no response")
    return match.group(1) in {"yes", "y"}


def parse_prompt_2(text: str, target: str) -> bool:
    """Report whether *text* equals *target* exactly.

    The comparison is a plain ``==``: casing and whitespace are significant
    and no label normalization is applied.
    """
    # NOTE(review): a normalized comparison (normalize_label on both sides)
    # was sketched here previously - confirm exact matching is intended.
    is_exact_match = text == target
    return is_exact_match




def chunk(lst, n):
    """Yield consecutive slices of *lst*, each holding at most *n* items.

    Slices are produced lazily and in order; the final slice is shorter when
    ``len(lst)`` is not a multiple of *n*. An empty *lst* yields nothing.
    """
    total = len(lst)
    for start in range(0, total, n):
        stop = start + n
        yield lst[start:stop]




def encode_base64_image(image: Image.Image) -> str:
    """Serialize *image* as PNG and return the bytes as a base64 ASCII string.

    The result is the bare base64 payload; no ``data:`` URL prefix is added.
    """
    with io.BytesIO() as png_buffer:
        # The image is always written as PNG, whatever its source format.
        image.save(png_buffer, format="PNG")
        raw_png = png_buffer.getvalue()
    return base64.b64encode(raw_png).decode("ascii")

def decode_base64_image(b64: str) -> Image.Image:
    """Decode a base64 string (optionally a ``data:`` URL) into an RGB image.

    When the input starts with ``data:`` (ignoring case and surrounding
    whitespace) and contains a comma, everything up to and including the
    first comma is discarded before decoding. The decoded bytes are opened
    with PIL and converted to RGB mode.
    """
    payload = b64
    if b64.strip().lower().startswith("data:") and "," in b64:
        # Drop the "data:<mime>;base64" header, keep what follows the comma.
        _, payload = b64.split(",", 1)
    raw_bytes = base64.b64decode(payload)
    opened = Image.open(io.BytesIO(raw_bytes))
    return opened.convert("RGB")