JairoDanielMT's picture
Upload folder using huggingface_hub
4ef6c2b verified
Raw
History Blame Contribute Delete
2.21 kB
import json
import sqlite3
from typing import List, Dict, Any, Tuple
class RegressionDetector:
    """Score the traits stored for each entity against a gold dataset.

    The gold dataset is a UTF-8 JSON file holding a list of records. Each
    record names an ``"entity"`` and may list ``"required_traits"`` and
    ``"forbidden_traits"``. The traits actually stored for an entity are read
    from the ``entity_visual_traits`` table of a SQLite database.
    """

    def __init__(self, db_path: str, gold_dataset_path: str):
        """Remember where the database and gold dataset live.

        Nothing is opened here; files are only read by the methods below.

        Args:
            db_path: Path to the SQLite database containing the
                ``entity_visual_traits`` table.
            gold_dataset_path: Path to the UTF-8 JSON gold dataset.
        """
        self.db_path = db_path
        self.gold_dataset_path = gold_dataset_path

    def load_gold_dataset(self) -> List[Dict]:
        """Read the gold dataset file and return its parsed JSON content.

        Raises:
            OSError: If the file cannot be opened.
            json.JSONDecodeError: If the file is not valid JSON.
        """
        with open(self.gold_dataset_path, "r", encoding="utf-8") as f:
            return json.load(f)

    def get_entity_traits(self, conn: sqlite3.Connection, entity_id: str) -> List[str]:
        """Return every ``trait_id`` row stored for *entity_id*.

        Duplicate rows in the table are returned as duplicates (no DISTINCT).

        Args:
            conn: An open SQLite connection; it is not closed here.
            entity_id: The entity whose traits are fetched.
        """
        cursor = conn.cursor()
        cursor.execute("SELECT trait_id FROM entity_visual_traits WHERE entity_id = ?", (entity_id,))
        return [row[0] for row in cursor.fetchall()]

    def run_ci_gates(self) -> Dict[str, Any]:
        """Compute the CI gate metrics over the whole gold dataset.

        Counts are pooled across all records before the percentages are taken.

        Returns:
            A dict with:

            - ``"CRA"``: percentage of required traits that were found among
              the entity's stored traits; 100.0 when no record requires any.
            - ``"Trait_Precision"``: percentage of stored traits that are not
              forbidden; 100.0 when no forbidden trait was found.
            - ``"Regression"`` and ``"Provenance_Coverage"``: fixed placeholder
              values, not computed from the data.

        Raises:
            KeyError: If a gold record has no ``"entity"`` key.
            sqlite3.Error: On database failures (e.g. the table is missing).
        """
        gold_data = self.load_gold_dataset()

        total_required = 0
        total_required_found = 0
        total_generated = 0
        total_forbidden_found = 0

        conn = sqlite3.connect(self.db_path)
        try:
            for record in gold_data:
                entity = record["entity"]
                # ``or []`` also covers keys that are present but null in the
                # JSON, which would otherwise make the loops below raise.
                required = record.get("required_traits") or []
                forbidden = record.get("forbidden_traits") or []
                traits = self.get_entity_traits(conn, entity)

                for req in required:
                    total_required += 1
                    if req in traits:
                        total_required_found += 1

                for trait in traits:
                    total_generated += 1
                    if trait in forbidden:
                        total_forbidden_found += 1
        finally:
            # Close the connection even when a record or query raises mid-loop.
            conn.close()

        # Calculate Metrics
        cra = (total_required_found / total_required) * 100 if total_required > 0 else 100.0
        # Trait Precision: 100% means 0 forbidden traits were found. A non-zero
        # forbidden count implies total_generated > 0, so the division is safe.
        tp = 100.0 if total_forbidden_found == 0 else ((total_generated - total_forbidden_found) / total_generated) * 100

        return {
            "CRA": cra,
            "Trait_Precision": tp,
            "Regression": 0.0,  # Simulated for first run
            "Provenance_Coverage": 100.0,  # Simulated
        }