Iostream-Li's picture
Add files using upload-large-folder tool
ff78003 verified
/**
* causal.adapter — 真查 PG `causal_edges` / `literature_buckets`。
*
* 支持过滤参数:
* - bucket_id 按 bucket_id 等值
* - bucket_ids 按 bucket_id 列表
* - bucket_key 按 literature_buckets.bucket_key (jsonb 等值,见 README)
* - edge_type 按 causal_edges.equation_type 过滤(E1..E6)
* - record_status / effect_status / paper_id
* - score_min / score_max — 按 pooled_estimate 的点估计数值区间(依次取 value / estimate / mean 等键中的第一个数值)
*
* 返回 KnowledgeItem.payload 含:
* - edge / equation_type / record_status / effect_status / pooled_estimate
* - confidence_interval — 从 pooled_estimate JSON 抽出的 [lo, hi] 二元组
* - paper_count — 同 bucket 内 distinct paper_id 计数(GROUP BY 一次性算)
*/
import { and, eq, inArray, sql } from "drizzle-orm";
import { causalEdges, literatureBuckets, db } from "@workspace/db";
import { enforceBudget } from "./budget-guard.ts";
import { fingerprintParams, recordCall } from "./telemetry.ts";
import type {
KnowledgeAdapter,
KnowledgeItem,
KnowledgeQuery,
KnowledgeResult,
} from "./types.ts";
/** Filters accepted by the causal adapter. Every field is optional; absent = no filter. */
interface CausalParams {
  /** Exact match on `causal_edges.bucket_id`. */
  bucket_id?: string;
  /** Match any of these `causal_edges.bucket_id` values. */
  bucket_ids?: string[];
  /** jsonb equality on `literature_buckets.bucket_key`; resolved to bucket ids before the main query. */
  bucket_key?: Record<string, unknown>;
  /** Exact match on `causal_edges.equation_type`. */
  edge_type?: string;
  /** Exact match on `causal_edges.record_status`. */
  record_status?: string;
  /** Exact match on `causal_edges.effect_status`. */
  effect_status?: string;
  /** Exact match on `causal_edges.paper_id`. */
  paper_id?: string;
  /** Inclusive lower bound on the bucket's pooled point estimate. */
  score_min?: number;
  /** Inclusive upper bound on the bucket's pooled point estimate. */
  score_max?: number;
}

/**
 * Narrow the untyped `q.params` bag into {@link CausalParams}.
 *
 * A value of the wrong runtime type is dropped (treated as "filter not
 * given") rather than rejected, so this never throws.
 *
 * - `bucket_ids`: a non-array is ignored; non-string entries are filtered out.
 * - `bucket_key`: only plain objects (not arrays, not null) are accepted.
 * - `score_min` / `score_max`: only *finite* numbers are accepted. NaN and
 *   ±Infinity would otherwise be bound into the SQL range comparison, where
 *   they make the filter meaningless (NaN) or may be rejected by the server.
 *
 * @param p - raw, caller-supplied parameter object.
 * @returns the recognised filters; unrecognised keys are ignored.
 */
function readParams(p: Record<string, unknown>): CausalParams {
  const str = (key: string): string | undefined => {
    const v = p[key];
    return typeof v === "string" ? v : undefined;
  };
  const finite = (key: string): number | undefined => {
    const v = p[key];
    return typeof v === "number" && Number.isFinite(v) ? v : undefined;
  };
  const rawIds = p["bucket_ids"];
  const rawKey = p["bucket_key"];
  // Key order is kept stable because the result is echoed back verbatim as
  // `sourceMetadata.filters`.
  return {
    bucket_id: str("bucket_id"),
    bucket_ids: Array.isArray(rawIds)
      ? rawIds.filter((x): x is string => typeof x === "string")
      : undefined,
    bucket_key:
      typeof rawKey === "object" && rawKey !== null && !Array.isArray(rawKey)
        ? (rawKey as Record<string, unknown>)
        : undefined,
    edge_type: str("edge_type"),
    record_status: str("record_status"),
    effect_status: str("effect_status"),
    paper_id: str("paper_id"),
    score_min: finite("score_min"),
    score_max: finite("score_max"),
  };
}
/**
 * Pull a confidence interval out of a `pooled_estimate` JSON value.
 *
 * Recognised shapes, tried in this order (first match wins):
 *   1. `{ ci: [lo, hi] }` — an array of exactly two numbers
 *   2. `{ ci_low, ci_high }`
 *   3. `{ lo, hi }`
 *   4. `{ lower, upper }`
 *
 * Both bounds must be of type `number`; no ordering check is made between them.
 *
 * @param pooled - the raw `pooled_estimate` value (any type).
 * @returns `[lo, hi]`, or `null` when `pooled` is not a non-null object or
 *   none of the shapes above match.
 */
export function extractCI(
  pooled: unknown,
): [number, number] | null {
  if (pooled === null || typeof pooled !== "object") {
    return null;
  }
  const rec = pooled as Record<string, unknown>;
  const isNum = (v: unknown): v is number => typeof v === "number";

  // Shape 1: packed two-element array.
  const packed: unknown = rec["ci"];
  if (Array.isArray(packed) && packed.length === 2) {
    const first: unknown = packed[0];
    const second: unknown = packed[1];
    if (isNum(first) && isNum(second)) {
      return [first, second];
    }
  }

  // Shapes 2-4: separate lower/upper keys, checked in priority order.
  const boundKeys = [
    ["ci_low", "ci_high"],
    ["lo", "hi"],
    ["lower", "upper"],
  ] as const;
  for (const [lowKey, highKey] of boundKeys) {
    const low = rec[lowKey];
    const high = rec[highKey];
    if (isNum(low) && isNum(high)) {
      return [low, high];
    }
  }
  return null;
}
/**
 * Pull the pooled point estimate out of a `pooled_estimate` JSON value.
 *
 * Checks the keys `value`, `estimate`, `mean`, `score` in that order and
 * returns the first one holding a finite number. Non-number values (including
 * numeric strings) and NaN/±Infinity are skipped.
 *
 * @param pooled - the raw `pooled_estimate` value (any type).
 * @returns the estimate, or `null` if `pooled` is not a non-null object or no
 *   key qualifies.
 */
export function extractScore(pooled: unknown): number | null {
  if (pooled === null || typeof pooled !== "object") {
    return null;
  }
  const rec = pooled as Record<string, unknown>;
  const found = ["value", "estimate", "mean", "score"]
    .map((key) => rec[key])
    .find((v): v is number => typeof v === "number" && Number.isFinite(v));
  return found ?? null;
}
/**
 * Normalize the caller's requested page size into a LIMIT PostgreSQL accepts.
 *
 * Missing or non-finite values fall back to 50, fractional values are floored,
 * and the result is clamped to [0, 500]. Without this, a negative, NaN or
 * fractional limit would be bound straight into LIMIT and may be rejected by
 * the server.
 */
function normalizeLimit(requested: unknown): number {
  const n =
    typeof requested === "number" && Number.isFinite(requested)
      ? Math.floor(requested)
      : 50;
  return Math.min(Math.max(n, 0), 500);
}

/**
 * SQL expression for the bucket's pooled point estimate, mirroring
 * `extractScore`. It is the first of the keys value / estimate / mean / score
 * whose JSON type is `number`, as `numeric`. It is NULL if no key qualifies or
 * the edge has no joined bucket row.
 *
 * Each key is guarded with `jsonb_typeof(...) = 'number'`, so a non-numeric
 * value (a string such as "NA", a nested object) yields NULL instead of
 * aborting the whole query with a numeric-cast error. This also keeps the
 * range filter consistent with the `score` reported on each item. The explicit
 * `::jsonb` cast is a no-op on jsonb and keeps the expression valid if the
 * column is plain json (its declared type is not visible here).
 */
function pooledScoreSql() {
  const pe = sql`(${literatureBuckets.pooledEstimate})::jsonb`;
  return sql`COALESCE(
    CASE WHEN jsonb_typeof(${pe} -> 'value') = 'number' THEN (${pe} ->> 'value')::numeric END,
    CASE WHEN jsonb_typeof(${pe} -> 'estimate') = 'number' THEN (${pe} ->> 'estimate')::numeric END,
    CASE WHEN jsonb_typeof(${pe} -> 'mean') = 'number' THEN (${pe} ->> 'mean')::numeric END,
    CASE WHEN jsonb_typeof(${pe} -> 'score') = 'number' THEN (${pe} ->> 'score')::numeric END
  )`;
}

/**
 * Best-effort, fire-and-forget telemetry.
 *
 * Whatever `recordCall` returns is wrapped in a promise, so an async failure
 * is caught here rather than becoming an unhandled rejection (which terminates
 * the process by default on current Node.js). A synchronous throw is caught
 * too, because this runs inside `finally`, where a throw would replace the
 * query's own result or error.
 */
function recordCallSafely(entry: Parameters<typeof recordCall>[0]): void {
  const warn = (e: unknown): void => {
    console.warn("[causal.adapter] telemetry recordCall failed:", e);
  };
  try {
    void Promise.resolve(recordCall(entry)).catch(warn);
  } catch (e: unknown) {
    warn(e);
  }
}

/**
 * Knowledge adapter backed by PostgreSQL `causal_edges`, left-joined to
 * `literature_buckets`.
 *
 * Filters (see `readParams`) are ANDed together. The one exception is
 * `bucket_ids` and the ids resolved from `bucket_key`: these are merged into a
 * single IN list, i.e. unioned with each other. `bucket_id` is still ANDed
 * with that list.
 *
 * Throws whatever the database layer throws; the error message is recorded in
 * telemetry before rethrowing.
 */
export const causalAdapter: KnowledgeAdapter = {
  kind: "causal_network",
  status: "real",
  async query(q: KnowledgeQuery): Promise<KnowledgeResult> {
    enforceBudget(q.capabilityId, "causal_network");
    const params = readParams(q.params ?? {});
    const limit = normalizeLimit(q.limit);
    const fp = fingerprintParams({ ...params, limit });
    const t0 = Date.now();
    let error: string | null = null;
    let items: KnowledgeItem[] = [];
    try {
      // Step 1: if bucket_key was given, resolve it to a list of bucket_ids.
      // drizzle has no jsonb-equality helper, so use a raw sql template; the
      // JSON text is bound as a query parameter, not spliced into the SQL.
      // NOTE: at most 500 matching buckets are used; further matches are ignored.
      let resolvedBucketIds: string[] | undefined;
      if (params.bucket_key) {
        const keyJson = JSON.stringify(params.bucket_key);
        const matched = await db
          .select({ bucketId: literatureBuckets.bucketId })
          .from(literatureBuckets)
          .where(sql`${literatureBuckets.bucketKey} = ${keyJson}::jsonb`)
          .limit(500);
        resolvedBucketIds = matched.map((m) => m.bucketId);
        if (resolvedBucketIds.length === 0) {
          // bucket_key matched no bucket: return empty, with an explicit reason.
          return {
            kind: "causal_network",
            items: [],
            cursor: null,
            sourceMetadata: {
              primarySource: "pg:causal_edges",
              limit,
              reason: "bucket_key_no_match",
              filters: params,
            },
          };
        }
      }

      // Step 2: build the WHERE clause.
      const conds = [];
      if (params.bucket_id) conds.push(eq(causalEdges.bucketId, params.bucket_id));
      // bucket_ids ∪ ids resolved from bucket_key, deduplicated.
      const allBucketIds = Array.from(
        new Set([...(params.bucket_ids ?? []), ...(resolvedBucketIds ?? [])]),
      );
      if (allBucketIds.length > 0) {
        conds.push(inArray(causalEdges.bucketId, allBucketIds));
      }
      if (params.edge_type) conds.push(eq(causalEdges.equationType, params.edge_type));
      if (params.record_status)
        conds.push(eq(causalEdges.recordStatus, params.record_status));
      if (params.effect_status)
        conds.push(eq(causalEdges.effectStatus, params.effect_status));
      if (params.paper_id) conds.push(eq(causalEdges.paperId, params.paper_id));
      // Inclusive range on the pooled point estimate. Uses the same key
      // precedence as extractScore, so filtered rows agree with item.score.
      if (typeof params.score_min === "number") {
        conds.push(sql`(${pooledScoreSql()} >= ${params.score_min})`);
      }
      if (typeof params.score_max === "number") {
        conds.push(sql`(${pooledScoreSql()} <= ${params.score_max})`);
      }
      const whereExpr = conds.length > 0 ? and(...conds) : undefined;

      // Step 3: fetch edges with their bucket's key and pooled estimate.
      const rows = await db
        .select({
          edgeId: causalEdges.edgeId,
          paperId: causalEdges.paperId,
          bucketId: causalEdges.bucketId,
          equationType: causalEdges.equationType,
          recordStatus: causalEdges.recordStatus,
          effectStatus: causalEdges.effectStatus,
          edge: causalEdges.edge,
          pooledEstimate: literatureBuckets.pooledEstimate,
          bucketKey: literatureBuckets.bucketKey,
        })
        .from(causalEdges)
        .leftJoin(
          literatureBuckets,
          eq(causalEdges.bucketId, literatureBuckets.bucketId),
        )
        .where(whereExpr ?? sql`TRUE`)
        .limit(limit);

      // Step 4: one GROUP BY query for the distinct paper count of every
      // bucket seen above. The count covers all edges in the bucket, not just
      // the filtered ones.
      const bucketIdsForCount = Array.from(
        new Set(rows.map((r) => r.bucketId).filter((x): x is string => !!x)),
      );
      const paperCountByBucket = new Map<string, number>();
      if (bucketIdsForCount.length > 0) {
        const counts = await db
          .select({
            bucketId: causalEdges.bucketId,
            paperCount: sql<number>`COUNT(DISTINCT ${causalEdges.paperId})`.as(
              "paper_count",
            ),
          })
          .from(causalEdges)
          .where(inArray(causalEdges.bucketId, bucketIdsForCount))
          .groupBy(causalEdges.bucketId);
        for (const c of counts) {
          // Number(): the driver may hand COUNT back as a string.
          if (c.bucketId)
            paperCountByBucket.set(c.bucketId, Number(c.paperCount) || 0);
        }
      }

      items = rows.map((r) => {
        const ci = extractCI(r.pooledEstimate);
        const score = extractScore(r.pooledEstimate);
        return {
          id: r.edgeId,
          score,
          payload: {
            paper_id: r.paperId,
            bucket_id: r.bucketId,
            bucket_key: r.bucketKey,
            equation_type: r.equationType,
            record_status: r.recordStatus,
            effect_status: r.effectStatus,
            edge: r.edge,
            pooled_estimate: r.pooledEstimate,
            confidence_interval: ci,
            paper_count: r.bucketId
              ? (paperCountByBucket.get(r.bucketId) ?? null)
              : null,
          },
          origin: "pg:causal_edges",
        };
      });
    } catch (err) {
      error = err instanceof Error ? err.message : String(err);
      throw err;
    } finally {
      recordCallSafely({
        adapter: "causal_network",
        capabilityId: q.capabilityId,
        latencyMs: Date.now() - t0,
        hitCount: items.length,
        cacheHit: null,
        error,
        paramsFingerprint: fp,
      });
    }
    return {
      kind: "causal_network",
      items,
      cursor: null,
      sourceMetadata: {
        primarySource: "pg:causal_edges",
        limit,
        filters: params,
      },
    };
  },
};