# Source: backend/data_service.py (Hugging Face upload by heymenn, commit c1442ee, verified)
import os
import pandas as pd
# NOTE(review): DatasetDict is imported but never used in this file.
from datasets import load_dataset, Dataset, DatasetDict
from huggingface_hub import login
import logging
from typing import List, Optional, Dict, Any
from dotenv import load_dotenv
# Module-level logger; handlers/levels are configured by the application.
logger = logging.getLogger(__name__)
# Load envs: default .env first, then the local override one directory up.
# NOTE(review): load_dotenv() by default does NOT override variables that are
# already set, so the first file found wins — confirm intended precedence.
load_dotenv()
load_dotenv("../.env.local")
class DataService:
    """Persistence layer backed by a Hugging Face Hub dataset.

    Each logical "table" (files, refined, patterns, results) is stored as a
    separate dataset config on the Hub and mirrored in memory as a pandas
    DataFrame. Reads are served from memory; every mutation pushes the
    affected table back to the Hub (best-effort — failures are logged).
    """

    def __init__(self):
        # Credentials and target dataset come from the environment
        # (.env files are loaded at module import time).
        self.hf_token = os.getenv("HF_TOKEN")
        self.dataset_name = os.getenv("HF_DATASET_NAME")
        if not self.hf_token or not self.dataset_name:
            # Not fatal: the service still works in-memory, saves are skipped.
            logger.error("HF_TOKEN or HF_DATASET_NAME not set via environment variables.")
        if self.hf_token:
            login(token=self.hf_token)
        self.configs = ["files", "refined", "patterns", "results"]
        self.data: Dict[str, pd.DataFrame] = {}
        self._load_data()

    def _load_data(self):
        """Loads data from HF Hub for each config. Initializes empty if not found."""
        for config in self.configs:
            try:
                # load_dataset returns a DatasetDict when no split is given,
                # so request "train" explicitly.
                ds = load_dataset(self.dataset_name, config, split="train")
                self.data[config] = ds.to_pandas()
                logger.info(f"Loaded config '{config}' with {len(self.data[config])} rows.")
            except Exception as e:
                # Missing config / first run / network failure: start empty.
                logger.warning(f"Could not load config '{config}' from HF: {e}. Initializing empty.")
                self.data[config] = pd.DataFrame()

    def _save(self, config_name: str):
        """Pushes the specific config DataFrame to HF Hub (no-op without credentials)."""
        if not self.hf_token or not self.dataset_name:
            logger.warning("Skipping save to HF: Credentials missing.")
            return
        try:
            df = self.data[config_name]
            ds = Dataset.from_pandas(df)
            ds.push_to_hub(self.dataset_name, config_name=config_name, token=self.hf_token)
            logger.info(f"Saved config '{config_name}' to HF Hub.")
        except Exception as e:
            # Best-effort persistence: keep serving from memory on failure.
            logger.error(f"Failed to save config '{config_name}': {e}")

    # --- Schema Helpers ---

    def _ensure_columns(self, config: str, columns: List[str]):
        """Guarantees `columns` exist on the config's DataFrame, even when empty.

        Missing columns on a non-empty frame are filled with None.
        """
        if self.data[config].empty:
            self.data[config] = pd.DataFrame(columns=columns)
        else:
            for col in columns:
                if col not in self.data[config].columns:
                    self.data[config][col] = None

    def _next_id(self, config: str, id_column: str) -> int:
        """Returns max(id_column) + 1 for the config's table, or 1 when empty.

        Non-numeric / missing IDs (e.g. None filled in by _ensure_columns)
        are coerced to NaN and ignored instead of raising ValueError.
        """
        df = self.data[config]
        if df.empty:
            return 1
        max_id = pd.to_numeric(df[id_column], errors="coerce").max()
        if pd.isna(max_id):
            return 1
        return int(max_id) + 1

    # --- File Operations ---

    def get_all_files(self) -> List[Dict[str, Any]]:
        """Returns every file row as a list of dicts ([] when the table is empty)."""
        if self.data["files"].empty:
            return []
        return self.data["files"].to_dict(orient="records")

    def get_file_content(self, file_id: str) -> Optional[str]:
        """Returns the stored content for `file_id`, or None if unknown."""
        df = self.data["files"]
        if df.empty:
            return None
        row = df[df["file_id"] == file_id]
        if not row.empty:
            return row.iloc[0]["content"]
        return None

    def add_file(self, file_data: Dict[str, Any]):
        """Upserts a file row: any existing row with the same file_id is replaced."""
        self._ensure_columns("files", ["file_id", "working_group", "meeting", "type", "status", "agenda_item", "content", "filename", "timestamp"])
        df = self.data["files"]
        if not df.empty:
            # Drop any previous version of this file before appending.
            df = df[df["file_id"] != file_data["file_id"]]
        new_row = pd.DataFrame([file_data])
        self.data["files"] = pd.concat([df, new_row], ignore_index=True)
        self._save("files")

    # --- Refined Operations ---

    def get_refined_output(self, file_id: str) -> Optional[str]:
        """Returns the refined output text for `file_id`, or None if absent."""
        df = self.data["refined"]
        if df.empty:
            return None
        row = df[df["file_id"] == file_id]
        if not row.empty:
            return row.iloc[0]["refined_output"]
        return None

    def add_refined(self, file_id: str, refined_output: str) -> int:
        """Stores a refined output for a file and returns the new refined_id."""
        self._ensure_columns("refined", ["refined_id", "refined_output", "file_id"])
        next_id = self._next_id("refined", "refined_id")
        new_row = pd.DataFrame([{
            "refined_id": next_id,
            "refined_output": refined_output,
            "file_id": file_id
        }])
        self.data["refined"] = pd.concat([self.data["refined"], new_row], ignore_index=True)
        self._save("refined")
        return next_id

    def get_refined_by_file_id(self, file_id: str) -> Optional[Dict[str, Any]]:
        """Returns the first refined row for `file_id` as a dict, or None."""
        df = self.data["refined"]
        if df.empty:
            return None
        row = df[df["file_id"] == file_id]
        if not row.empty:
            return row.iloc[0].to_dict()
        return None

    # --- Pattern Operations ---

    def get_patterns(self) -> List[Dict[str, Any]]:
        """Returns all pattern rows as dicts ([] when empty)."""
        if self.data["patterns"].empty:
            return []
        return self.data["patterns"].to_dict(orient="records")

    def get_pattern(self, pattern_id: int) -> Optional[Dict[str, Any]]:
        """Returns the pattern row with `pattern_id` as a dict, or None."""
        df = self.data["patterns"]
        if df.empty:
            return None
        row = df[df["pattern_id"] == pattern_id]
        if not row.empty:
            return row.iloc[0].to_dict()
        return None

    def add_pattern(self, pattern_name: str, prompt: str) -> int:
        """Creates a pattern and returns its freshly assigned pattern_id."""
        self._ensure_columns("patterns", ["pattern_id", "pattern_name", "prompt"])
        next_id = self._next_id("patterns", "pattern_id")
        new_row = pd.DataFrame([{
            "pattern_id": next_id,
            "pattern_name": pattern_name,
            "prompt": prompt
        }])
        self.data["patterns"] = pd.concat([self.data["patterns"], new_row], ignore_index=True)
        self._save("patterns")
        return next_id

    def update_pattern(self, pattern_id: int, pattern_name: str, prompt: str) -> bool:
        """Updates name/prompt of an existing pattern. Returns False if not found."""
        df = self.data["patterns"]
        if df.empty:
            return False
        if pattern_id not in df["pattern_id"].values:
            return False
        self.data["patterns"].loc[df["pattern_id"] == pattern_id, ["pattern_name", "prompt"]] = [pattern_name, prompt]
        self._save("patterns")
        return True

    # --- Result Operations ---

    def get_existing_result(self, file_id: str):
        """Looks up the latest result for a file via its refined row.

        Equivalent to:
            SELECT ... FROM result r JOIN refined ref ... WHERE refined.file_id = ?

        Returns a (result_row_or_None, refined_id_or_None, file_name) tuple.
        The result dict is augmented with 'pattern_name' and a 'content' alias
        for 'result_content' (the key main.py expects).
        """
        ref_row = self.get_refined_by_file_id(file_id)
        file_df = self.data["files"]
        file_name = "Unknown File"
        if not file_df.empty:
            f_row = file_df[file_df["file_id"] == file_id]
            if not f_row.empty:
                file_name = f_row.iloc[0]["filename"]
        if not ref_row:
            return None, None, file_name
        refined_id = ref_row["refined_id"]
        res_df = self.data["results"]
        if res_df.empty:
            return None, refined_id, file_name
        match = res_df[res_df["refined_id"] == refined_id]
        if match.empty:
            return None, refined_id, file_name
        # Multiple results may exist for one refined_id; take the latest one.
        result_row = match.iloc[-1].to_dict()
        # Resolve the human-readable pattern name for the result.
        pat_df = self.data["patterns"]
        pattern_name = "Unknown"
        if not pat_df.empty and "pattern_id" in result_row:
            pat_match = pat_df[pat_df["pattern_id"] == result_row["pattern_id"]]
            if not pat_match.empty:
                pattern_name = pat_match.iloc[0]["pattern_name"]
        result_row["pattern_name"] = pattern_name
        # Normalize keys: main.py expects 'content' for result_content.
        result_row["content"] = result_row.get("result_content")
        return result_row, refined_id, file_name

    def add_result(self, pattern_id: int, refined_id: int, result_content: str, methodology: str, context: str, problem: str, classification: str = "UNCLASSIFIED") -> int:
        """Stores an analysis result and returns its freshly assigned result_id."""
        self._ensure_columns("results", ["result_id", "pattern_id", "refined_id", "result_content", "methodology", "context", "problem", "classification"])
        next_id = self._next_id("results", "result_id")
        new_row = pd.DataFrame([{
            "result_id": next_id,
            "pattern_id": pattern_id,
            "refined_id": refined_id,
            "result_content": result_content,
            "methodology": methodology,
            "context": context,
            "problem": problem,
            "classification": classification
        }])
        self.data["results"] = pd.concat([self.data["results"], new_row], ignore_index=True)
        self._save("results")
        return next_id

    def update_classification(self, result_id: int, classification: str) -> bool:
        """Sets the classification of a result. Returns False if result_id unknown.

        NOTE(review): raises a bare Exception when the results table is empty
        (kept for caller compatibility), but returns False for an unknown id —
        inconsistent with update_pattern; consider unifying.
        """
        df = self.data["results"]
        if df.empty:
            raise Exception("No results found")
        if result_id not in df["result_id"].values:
            return False
        self.data["results"].loc[df["result_id"] == result_id, "classification"] = classification
        self._save("results")
        return True

    def get_all_results_joined(self) -> List[Dict[str, Any]]:
        """Joins results with patterns, refined and files; newest result first.

        Left joins, so rows survive even when a referenced pattern/refined/file
        row is missing (missing fields come back as NaN/None).
        """
        if self.data["results"].empty:
            return []
        res_df = self.data["results"].copy()
        # Join Patterns (for pattern_name).
        pat_df = self.data["patterns"]
        if not pat_df.empty:
            res_df = res_df.merge(pat_df[["pattern_id", "pattern_name"]], on="pattern_id", how="left")
        # Join Refined (to recover file_id).
        ref_df = self.data["refined"]
        if not ref_df.empty:
            res_df = res_df.merge(ref_df[["refined_id", "file_id"]], on="refined_id", how="left")
        # Join Files (for filename).
        file_df = self.data["files"]
        if not file_df.empty:
            res_df = res_df.merge(file_df[["file_id", "filename"]], on="file_id", how="left")
        # Shape rows for the API: id, file_name, content, classification, etc.
        out = []
        for _, row in res_df.iterrows():
            out.append({
                "id": row.get("result_id"),
                "file_name": row.get("filename"),
                "content": row.get("result_content"),
                "classification": row.get("classification"),
                "pattern_name": row.get("pattern_name"),
                "methodology": row.get("methodology"),
                "context": row.get("context"),
                "problem": row.get("problem")
            })
        # Newest first; treat missing ids as 0 for sorting purposes.
        out.sort(key=lambda x: x["id"] or 0, reverse=True)
        return out