File size: 2,748 Bytes
ee749be
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
"""LLM-backed triple/entity extractor for PoC.

This module provides a small wrapper that asks the LLM (via LangChain ChatOpenAI)
to extract a small set of triples from a text chunk. It returns a list of dicts:
    {"subject": ..., "predicate": ..., "object": ..., "sentence": ..., "confidence": float}

The implementation is intentionally conservative and small for a Spaces-compatible PoC.
"""
from typing import List, Dict
import json

from langchain.chat_models import ChatOpenAI
from langchain.schema import HumanMessage, SystemMessage


def _parse_json_payload(raw: str) -> object:
    """Return the JSON value embedded in *raw*, or an empty list if none is found."""
    try:
        return json.loads(raw)
    except (ValueError, RecursionError):
        pass

    # The reply may wrap the array in prose or a Markdown code fence. Try to
    # decode a JSON value starting at each "[" in turn; raw_decode tolerates
    # trailing text after the value, so "[...] some note" still parses.
    decoder = json.JSONDecoder()
    start = raw.find("[")
    while start != -1:
        try:
            value, _ = decoder.raw_decode(raw[start:])
        except (ValueError, RecursionError):
            start = raw.find("[", start + 1)
        else:
            return value
    return []


def _coerce_confidence(value: object, default: float = 0.5) -> float:
    """Convert *value* to a float clamped to [0.0, 1.0]; fall back to *default*."""
    if value is None:
        return default
    try:
        conf = float(value)
    except (TypeError, ValueError, OverflowError):
        return default
    if conf != conf:  # NaN is the only float that is not equal to itself
        return default
    return min(1.0, max(0.0, conf))


def extract_triples_with_llm(text: str, max_triples: int = 6, model_name: str = "gpt-3.5-turbo") -> List[Dict]:
    """Extract factual triples from text using a Chat LLM.

    Args:
        text: The text chunk to extract triples from.
        max_triples: Upper bound on the number of triples returned. It is
            stated in the prompt and also enforced on the parsed result.
        model_name: Model name passed to ChatOpenAI.

    Returns:
        A list of dicts with the keys "subject", "predicate", "object",
        "sentence" (all str) and "confidence" (float in [0.0, 1.0]; 0.5 when
        the model gives no usable value). Items missing a subject, predicate
        or object are dropped. An empty list is returned when the reply
        contains no parseable JSON, when *text* is empty or whitespace-only,
        or when *max_triples* is not positive; in the last two cases the LLM
        is not called at all.

    Note: requires OPENAI_API_KEY in env for ChatOpenAI to work.
    """
    # Nothing can be extracted, so avoid a pointless API call.
    if max_triples <= 0 or not text or not text.strip():
        return []

    prompt = (
        "You are an assistant that extracts factual triples from a short text.\n"
        "Return a JSON array where each element is an object with keys: subject, predicate, object, sentence, confidence.\n"
        "Be concise and only return JSON. Confidence should be a float between 0.0 and 1.0.\n"
        f"Limit results to at most {max_triples} triples.\n\n"
        "Text:\n<<<TEXT_START>>>\n"
        + text
        + "\n<<<TEXT_END>>>\n"
    )

    # System message to instruct the output format strictly.
    system = SystemMessage(content="You output only JSON arrays. Do not add any extra text.")
    human = HumanMessage(content=prompt)

    llm = ChatOpenAI(model_name=model_name, temperature=0.0)
    resp = llm([system, human])
    raw = (resp.content or "").strip()

    data = _parse_json_payload(raw)
    if isinstance(data, dict):
        # A lone object is treated as a single candidate triple.
        data = [data]
    elif not isinstance(data, list):
        # Valid JSON that is not a collection (null, number, string, bool).
        data = []

    cleaned: List[Dict] = []
    for item in data:
        if len(cleaned) >= max_triples:
            break
        if not isinstance(item, dict):
            continue
        # Accept the short aliases s/p/o that models sometimes emit.
        subj = item.get("subject") or item.get("s")
        pred = item.get("predicate") or item.get("p")
        obj = item.get("object") or item.get("o")
        if not (subj and pred and obj):
            continue
        cleaned.append({
            "subject": str(subj),
            "predicate": str(pred),
            "object": str(obj),
            "sentence": str(item.get("sentence") or ""),
            "confidence": _coerce_confidence(item.get("confidence")),
        })
    return cleaned