# Grapher / mdl_utils.py
# Uploaded by Pragatik771 ("Upload 13 files", commit 6ea0c39).
import os
import requests
from pydantic import Field
import logging
from datetime import datetime
from typing import List, Optional, Any
import numpy as np
import pandas as pd
from pydantic import BaseModel, Field
# Module-level logger, named after this module (__name__).
logger = logging.getLogger(__name__)
def get_llm_provider(provider: Optional[str] = None):
    """
    Build the LLM provider identified by *provider*.

    Args:
        provider: Case-insensitive provider name. ``None`` or an empty string
            selects the default, ``"deepseek"``.

    Returns:
        A provider object exposing ``generate(prompt, ...)``.

    Raises:
        ValueError: If the (lower-cased) name is not a supported provider.
    """
    selected = (provider or "deepseek").lower()
    # Guard clause: reject anything we do not know how to construct.
    if selected != "deepseek":
        raise ValueError(f"Unsupported LLM provider: {selected}")
    return DeepSeekProvider()
# Chat-completions endpoint used by DeepSeekProvider.
DEEPSEEK_API_URL = "https://api.deepseek.com/v1/chat/completions"
# API key read from the environment once, when this module is imported (None if unset).
DEEPSEEK_API_KEY = os.environ.get("DEEPSEEK_API_KEY")
class DeepSeekProvider:
    """DeepSeek LLM provider that calls the DeepSeek chat-completions HTTP API."""

    def __init__(
        self,
        model_name: str = "deepseek-chat",
        api_key: Optional[str] = None,
        api_url: Optional[str] = None,
        timeout: float = 60,
    ):
        """
        Initialize the DeepSeek API provider.

        Args:
            model_name: Model to use (default: deepseek-chat).
            api_key: Explicit API key. When omitted, the key captured at import
                time (DEEPSEEK_API_KEY) is used, falling back to reading the
                DEEPSEEK_API_KEY environment variable now.
            api_url: Endpoint override; defaults to DEEPSEEK_API_URL.
            timeout: Seconds to wait for the HTTP request (default: 60).

        Raises:
            ValueError: If no API key can be found.
        """
        self.model_name = model_name
        self.api_url = api_url or DEEPSEEK_API_URL
        # The module-level constant is read once at import; re-check the
        # environment so a key exported/loaded after import is still found.
        self.api_key = api_key or DEEPSEEK_API_KEY or os.getenv("DEEPSEEK_API_KEY")
        self.timeout = timeout
        if not self.api_key:
            raise ValueError("DeepSeek API key not set. Please export DEEPSEEK_API_KEY.")

    def generate(
        self,
        prompt: str,
        max_tokens: int = 1000,
        temperature: float = 0.7,
    ) -> str:
        """
        Generate text using the DeepSeek API.

        The prompt is sent as the user message after a fixed system message
        ("You are a helpful AI assistant.").

        Args:
            prompt: Input prompt.
            max_tokens: Maximum tokens to generate.
            temperature: Sampling temperature.

        Returns:
            The first choice's message content with surrounding whitespace stripped.

        Raises:
            requests.HTTPError: On a non-2xx response (via raise_for_status).
            Exception: Any other request or response-parsing error is logged
                with a traceback and re-raised unchanged.
        """
        headers = {
            "Authorization": f"Bearer {self.api_key}",
            "Content-Type": "application/json",
        }
        payload = {
            "model": self.model_name,
            "messages": [
                {"role": "system", "content": "You are a helpful AI assistant."},
                {"role": "user", "content": prompt},
            ],
            "max_tokens": max_tokens,
            "temperature": temperature,
        }
        try:
            response = requests.post(
                self.api_url, headers=headers, json=payload, timeout=self.timeout
            )
            response.raise_for_status()
            data = response.json()
            return data["choices"][0]["message"]["content"].strip()
        except Exception as e:
            logger.error(f"DeepSeek API error: {e}", exc_info=True)
            raise
# ---------------------------------------------------------------------------
# Schema Models
# ---------------------------------------------------------------------------
class FieldDefinition(BaseModel):
    """Definition of a single dataset field (one column of the dataset).

    ``generate_mdl`` builds one of these per DataFrame column, and
    ``mdl_to_text`` renders them into the schema text used in LLM prompts.
    """

    # Column name as it appears in the dataset.
    name: str
    # Logical type label; generate_mdl emits "integer", "float", "datetime",
    # "boolean" or "string" (see infer_field_type).
    type: str
    # Free-text description; mdl_to_text only renders it when non-empty.
    description: str = ""
    # Whether the column may contain missing values.
    nullable: bool = True
    # Optional value-format hint; generate_mdl sets "ISO8601" for datetime fields.
    format: Optional[str] = None
    # Optional list of allowed values; generate_mdl fills it (as strings) for
    # string fields with at most 20 distinct non-null values.
    enum: Optional[List[Any]] = None
    # Optional unit of measure; not set by generate_mdl — presumably filled by
    # hand or by other code (TODO confirm).
    unit: Optional[str] = None
    # Optional sample value; generate_mdl stores the first non-null value as a string.
    example: Optional[Any] = None
class Constraint(BaseModel):
    """Dataset constraint definition (e.g. uniqueness or not-null on columns)."""

    # Identifier; generate_mdl uses "<column>_unique" / "<column>_not_null".
    name: str
    # Constraint kind; generate_mdl emits "unique" and "not_null". Other kinds
    # are shown by mdl_to_text only when a condition is present.
    type: str
    # Optional free-form condition text describing the constraint.
    condition: Optional[str] = None
    # Columns the constraint applies to, if any.
    columns: Optional[List[str]] = None
def _utc_now_iso() -> str:
    """Return the current UTC time as a naive ISO-8601 string.

    Produces the same text as the deprecated ``datetime.utcnow().isoformat()``
    (no ``+00:00`` suffix), so existing ``created_at`` values keep their format.
    """
    from datetime import timezone

    return datetime.now(timezone.utc).replace(tzinfo=None).isoformat()


class DatasetMDL(BaseModel):
    """Model Definition Language (MDL) schema for a whole dataset.

    Produced by ``generate_mdl`` and rendered for LLM prompts by ``mdl_to_text``.
    """

    # Dataset (table) name; SQL generation uses it as the table identifier.
    dataset: str
    description: str = ""
    fields: List[FieldDefinition]
    # default_factory gives every instance its own list instead of a shared one.
    constraints: List[Constraint] = Field(default_factory=list)
    # Creation timestamp (UTC, naive ISO-8601 string), evaluated per instance.
    created_at: str = Field(default_factory=_utc_now_iso)
    version: str = "1.0"

    class Config:
        # Encoders for numpy values during JSON serialization.
        # NOTE(review): class-based Config / json_encoders is pydantic v1 style
        # and deprecated under pydantic v2 — confirm which version is pinned.
        json_encoders = {
            np.integer: int,
            np.floating: float,
            np.ndarray: lambda v: v.tolist(),
        }
# ---------------------------------------------------------------------------
# Schema Inference
# ---------------------------------------------------------------------------
def infer_field_type(dtype) -> str:
    """Map a pandas/numpy dtype to an MDL logical type name.

    Args:
        dtype: A pandas or numpy dtype, e.g. ``df[col].dtype``.

    Returns:
        ``"integer"``, ``"float"``, ``"datetime"``, or ``"boolean"`` for those
        dtype families; ``"string"`` for everything else (object, categorical,
        string, timedelta, ...).
    """
    api = pd.api.types
    if api.is_integer_dtype(dtype):
        return "integer"
    if api.is_float_dtype(dtype):
        return "float"
    if api.is_datetime64_any_dtype(dtype):
        return "datetime"
    if api.is_bool_dtype(dtype):
        return "boolean"
    # Categorical, object and every remaining dtype map to "string". No explicit
    # categorical check is needed (it returned the same value), and
    # is_categorical_dtype is deprecated in pandas >= 2.1.
    return "string"
def generate_mdl(df: pd.DataFrame, dataset_name: str) -> DatasetMDL:
    """
    Generate a Model Definition Language (MDL) schema from a pandas DataFrame.

    For each column a FieldDefinition is built with:
      - ``type`` from ``infer_field_type`` and ``nullable`` if any value is missing;
      - ``format="ISO8601"`` for datetime columns;
      - ``enum`` (sorted, as strings) for string columns with <= 20 distinct
        non-null values;
      - ``example`` set to the first non-null value rendered as a string.

    Constraints added:
      - ``unique`` for columns whose name contains "id" or "code"
        (case-insensitive substring match, so e.g. "valid" also qualifies)
        and whose values are all distinct;
      - ``not_null`` for every column with no missing values.

    Column labels are converted with ``str()`` so non-string labels (e.g. the
    integer labels of a header-less CSV) are supported.

    Args:
        df: Input DataFrame
        dataset_name: Name of the dataset

    Returns:
        DatasetMDL: Generated schema
    """
    fields: List[FieldDefinition] = []
    for column in df.columns:
        series = df[column]
        label = str(column)
        field_type = infer_field_type(series.dtype)
        non_null = series.dropna()
        field = FieldDefinition(
            name=label,
            type=field_type,
            description=f"Column {label} of type {field_type}",
            # Cast numpy.bool_ to a plain bool for the pydantic model.
            nullable=bool(series.isna().any()),
        )
        if field_type == "datetime":
            field.format = "ISO8601"
        if field_type == "string":
            unique_values = non_null.unique()
            if len(unique_values) <= 20:
                field.enum = sorted(map(str, unique_values))
        if not non_null.empty:
            field.example = str(non_null.iloc[0])
        fields.append(field)

    constraints: List[Constraint] = []
    # Heuristic: identifier-like columns whose values are all distinct.
    for column in df.columns:
        label = str(column)
        lowered = label.lower()
        if ("id" in lowered or "code" in lowered) and df[column].is_unique:
            constraints.append(
                Constraint(name=f"{label}_unique", type="unique", columns=[label])
            )
    for column in df.columns:
        if not df[column].isna().any():
            label = str(column)
            constraints.append(
                Constraint(name=f"{label}_not_null", type="not_null", columns=[label])
            )

    return DatasetMDL(
        dataset=dataset_name,
        description=f"Auto-generated schema for {dataset_name}",
        fields=fields,
        constraints=constraints,
    )
# ---------------------------------------------------------------------------
# Conversions
# ---------------------------------------------------------------------------
def mdl_to_text(mdl: DatasetMDL) -> str:
    """Convert an MDL schema to the formatted text used in LLM prompts.

    Output lists the dataset name and description, then one line per field
    (type, optional description, nullable/required, up to 5 enum values, an
    example when present), then one line per constraint.

    Args:
        mdl: Schema to render.

    Returns:
        Newline-joined text.
    """
    lines = [
        f"Dataset: {mdl.dataset}",
        f"Description: {mdl.description}",
        "\nFields:",
    ]
    for field in mdl.fields:
        desc = f" - {field.name}: {field.type}"
        if field.description:
            desc += f" - {field.description}"
        desc += " (nullable)" if field.nullable else " (required)"
        if field.enum:
            # enum is List[Any]; stringify so non-str values don't break join().
            enum_preview = ", ".join(str(v) for v in field.enum[:5])
            if len(field.enum) > 5:
                enum_preview += f" and {len(field.enum) - 5} more..."
            desc += f" [enum: {enum_preview}]"
        # Check for None/empty rather than truthiness so examples like 0 or False show.
        example_text = "" if field.example is None else str(field.example)
        if example_text:
            desc += f" (e.g., {example_text})"
        lines.append(desc)

    if mdl.constraints:
        lines.append("\nConstraints:")
        for c in mdl.constraints:
            cols = ", ".join(str(col) for col in (c.columns or []))
            if c.type == "unique":
                lines.append(f" - {c.name}: Unique on {cols}")
            elif c.type == "not_null":
                lines.append(f" - {c.name}: Not null on {cols}")
            elif c.condition:
                lines.append(f" - {c.name}: {c.type} ({c.condition})")
            else:
                # Other constraint kinds without a condition were previously dropped.
                suffix = f" on {cols}" if cols else ""
                lines.append(f" - {c.name}: {c.type}{suffix}")
    return "\n".join(lines)
# ---------------------------------------------------------------------------
# RAG and Embeddings
# ---------------------------------------------------------------------------
def get_rag_context(
    query: str,
    dataset_name: str,
    top_k: int = 3,
    score_threshold: float = 0.7,
    db=None
) -> str:
    """
    Placeholder for retrieval-augmented (RAG) context lookup; always returns "".

    No retrieval or similarity scoring is performed. Each call logs a warning
    and returns an empty string, so callers can include the result in a
    prompt unconditionally. The parameters are accepted so existing callers
    keep working, but none of them are used yet.

    Args:
        query: Query string (unused).
        dataset_name: Dataset name (unused).
        top_k: Number of results to return (unused).
        score_threshold: Minimum similarity score, 0-1 (unused).
        db: Kept for backward compatibility (unused).

    Returns:
        An empty string.
    """
    # TODO: back this with a real vector store before relying on RAG context.
    logger.warning("RAG context retrieval is not implemented without a database")
    return ""
# ---------------------------------------------------------------------------
# SQL Query Generation
# ---------------------------------------------------------------------------
def clean_sql(sql: str, mdl: DatasetMDL) -> str:
    """Normalize LLM-generated SQL into one statement terminated by ``;``.

    - Surrounding whitespace and an enclosing Markdown code fence
      (```` ```sql ... ``` ````) are removed, and trailing semicolons are collapsed.
    - A bare fallback query of the exact form ``SELECT * FROM <table> LIMIT 100``
      is rewritten to list the schema's columns explicitly against
      ``mdl.dataset``.
    - Every other query (filters, ordering, other limits, aggregates) is kept
      verbatim. Previously any query merely containing "LIMIT 100" (including
      "LIMIT 1000") without GROUP BY was replaced, discarding its WHERE/ORDER BY.

    Args:
        sql: Raw SQL text returned by the model.
        mdl: Schema supplying the table name and column list.

    Returns:
        The cleaned SQL statement ending in ``;``.
    """
    import re

    text = sql.strip()
    fenced = re.match(r"^```[\w-]*\s*(.*?)\s*```$", text, re.DOTALL)
    if fenced:
        text = fenced.group(1)
    body = text.strip().rstrip(";").strip()

    if re.fullmatch(r'select\s+\*\s+from\s+("[^"]*"|\S+)\s+limit\s+100', body, re.IGNORECASE):
        # Double embedded quotes so identifiers stay valid quoted SQL names.
        cols = ", ".join('"' + str(f.name).replace('"', '""') + '"' for f in mdl.fields) or "*"
        table = '"' + str(mdl.dataset).replace('"', '""') + '"'
        return f"SELECT {cols} FROM {table} LIMIT 100;"
    return body + ";"
# def generate_sql_query(
# natural_language_query: str,
# mdl: DatasetMDL,
# model_provider: Optional[str] = None,
# dataset_name: Optional[str] = None,
# use_rag: bool = False,
# top_k: int = 3
# ) -> str:
# """
# Convert a natural language query to SQL using MDL schema.
# Args:
# natural_language_query: Query string
# mdl: Dataset schema
# model_provider: LLM provider
# dataset_name: Dataset name (unused, kept for backward compatibility)
# use_rag: Not used, kept for backward compatibility
# top_k: Not used, kept for backward compatibility
# Returns:
# SQL query string
# """
# try:
# # Get the LLM provider
# llm = get_llm_provider(provider=model_provider)
# # Create prompt with schema
# prompt = f"""
# You are an expert SQL generator.
# Your task is to create the most appropriate SQL query for a given natural language question.
# Database schema:
# {mdl_to_text(mdl)}
# Guidelines:
# - Always choose the minimum set of columns needed to answer the question.
# - If the query asks for a ratio, percentage, distribution, or comparison,
# use GROUP BY with aggregation (COUNT, SUM, AVG, etc.).
# - For "pie chart", "ratio", or "distribution", return grouped counts or proportions.
# - Do not SELECT all columns unless explicitly requested.
# - Always alias aggregate columns with meaningful names (e.g., gender_count, total_users).
# - Use the exact dataset name: "{mdl.dataset}" as the table.
# Natural language query: {natural_language_query}
# SQL query:
# """
# sql = llm.generate(prompt).strip()
# resulting_sql=clean_sql(sql, mdl)
# return resulting_sql
# except Exception as e:
# logger.error(f"SQL generation error: {e}")
# cols = ", ".join([f'"{f.name}"' for f in mdl.fields])
# return f'SELECT {cols} FROM "{mdl.dataset}" LIMIT 100;'
def generate_sql_query(
    natural_language_query: str,
    mdl: DatasetMDL,
    model_provider: Optional[str] = None,
    dataset_name: Optional[str] = None,
    use_rag: bool = False,
    top_k: int = 3
) -> str:
    """
    Convert a natural language query to SQL using the MDL schema.

    Args:
        natural_language_query: The user's question.
        mdl: Dataset schema; rendered into the prompt and used for cleanup.
        model_provider: LLM provider name passed to get_llm_provider.
        dataset_name: Unused; kept for backward compatibility.
        use_rag: Unused; kept for backward compatibility.
        top_k: Unused; kept for backward compatibility.

    Returns:
        - The literal string "INVALID QUERY" when the model reports that the
          question cannot be answered from the schema. The sentinel is
          recognized even if wrapped in code fences/quotes, followed by
          punctuation, or in different case.
        - Otherwise the model's SQL after clean_sql.
        - If anything raises (including provider construction, e.g. a missing
          API key, or the HTTP call), the error is logged with a traceback and
          a select-all fallback ``SELECT <all columns> FROM "<dataset>" LIMIT 100;``
          is returned.
    """
    import re

    try:
        llm = get_llm_provider(provider=model_provider)
        prompt = "\n".join([
            "You are an expert SQL generator.",
            "Your task is to create the most appropriate SQL query for a given natural language question.",
            "Database schema:",
            mdl_to_text(mdl),
            "Guidelines:",
            "- Only generate SQL queries that can be executed on this schema.",
            "- If the natural language query refers to columns, tables, or concepts NOT present in the schema,",
            "  or is ambiguous and cannot be mapped with high confidence, output exactly:",
            "  INVALID QUERY",
            "- Always choose the minimum set of columns needed to answer the question.",
            "- If the query asks for a ratio, percentage, distribution, or comparison,",
            "  use GROUP BY with aggregation (COUNT, SUM, AVG, etc.).",
            '- For "pie chart", "ratio", or "distribution", return grouped counts or proportions.',
            "- Do not SELECT all columns unless explicitly requested.",
            "- Always alias aggregate columns with meaningful names (e.g., gender_count, total_users).",
            f'- Use the exact dataset name: "{mdl.dataset}" as the table.',
            "- Output only the raw SQL text: no Markdown code fences, comments, or explanations.",
            f"Natural language query: {natural_language_query}",
            "Respond with either:",
            "1. A valid SQL query based strictly on the schema, OR",
            '2. The text "INVALID QUERY" (no explanation).',
        ])
        sql = llm.generate(prompt).strip()
        # Models sometimes decorate the sentinel (fences, quotes, trailing '.');
        # normalize before comparing so it is never passed on as SQL.
        verdict = re.sub(r"^```[\w-]*|```$", "", sql).strip().strip("`'\".;").strip()
        if verdict.upper() == "INVALID QUERY":
            logger.error("9999 Invalid query generated")
            return "INVALID QUERY"
        return clean_sql(sql, mdl)
    except Exception as e:
        logger.error(f"SQL generation error: {e}", exc_info=True)
        # Quote identifiers (doubling embedded quotes) so the fallback stays valid SQL.
        cols = ", ".join('"' + str(f.name).replace('"', '""') + '"' for f in mdl.fields) or "*"
        table = '"' + str(mdl.dataset).replace('"', '""') + '"'
        return f"SELECT {cols} FROM {table} LIMIT 100;"
# ---------------------------------------------------------------------------
# CSV to DataFrame
# ---------------------------------------------------------------------------
def process_csv_to_df(csv_content: str) -> pd.DataFrame:
    """Parse CSV text into a pandas DataFrame.

    The default parser is tried first. If it raises for any reason, the text
    is parsed again with ``engine="python"``; an error from that second
    attempt propagates to the caller.
    """
    from io import StringIO

    def _read(**read_options) -> pd.DataFrame:
        # Fresh buffer per attempt, since a failed attempt may have consumed it.
        return pd.read_csv(StringIO(csv_content), **read_options)

    try:
        return _read()
    except Exception:
        return _read(engine="python")