neovalle committed on
Commit
8b6a799
·
verified ·
1 Parent(s): 1b8c559

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +247 -0
app.py ADDED
@@ -0,0 +1,247 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # (paste into app.py)
2
+ import os
3
+ import json
4
+ from functools import lru_cache
5
+ from typing import List, Tuple, Optional, Any, Dict
6
+
7
+ import gradio as gr
8
+ import numpy as np
9
+ import pandas as pd
10
+
11
+ from rdflib import Graph, URIRef, Literal
12
+ from rdflib.namespace import RDFS, RDF, SKOS, DCTERMS
13
+
14
# --- Runtime configuration -------------------------------------------------
# Every value is read once at import time and can be overridden through the
# environment variable named in the lookup.

# Path of the Turtle ontology file parsed by _load_graph().
ONTOLOGY_PATH = os.environ.get("ONTOLOGY_PATH", "narratives.ttl")
# Default search method pre-selected in the UI ("keyword" or "embedding").
DEFAULT_SEARCH_METHOD = os.environ.get("SEARCH_METHOD", "keyword")
# Default prompt style pre-selected in the UI.
DEFAULT_STYLE = os.environ.get("PROMPT_STYLE", "balanced")
# Default number of concepts kept after scoring.
TOP_K_CONCEPTS = int(os.environ.get("TOP_K_CONCEPTS", "8"))
# Default number of neighbour hops used when expanding matched concepts.
EXPANSION_DEPTH = int(os.environ.get("EXPANSION_DEPTH", "1"))
# True only when the variable is (case-insensitively) the word "true".
# NOTE(review): not referenced by any other code visible in this file.
INCLUDE_RELATIONS = os.environ.get("INCLUDE_RELATIONS", "true").lower() == "true"

# Cohere model names used for embedding search and for the optional chat call.
COHERE_EMBED_MODEL = os.environ.get("COHERE_EMBED_MODEL", "embed-english-v3.0")
COHERE_CHAT_MODEL = os.environ.get("COHERE_CHAT_MODEL", "command-r")
23
+
24
class OntologyEntry:
    """Plain record describing one labelled subject of the ontology graph."""

    def __init__(self, uri: str, labels: List[str], alt_labels: List[str], description: str, types: List[str]):
        """Store the given fields unchanged as instance attributes.

        Args:
            uri: String form of the subject node.
            labels: Preferred labels of the subject.
            alt_labels: Alternative labels of the subject.
            description: Free-text description of the subject.
            types: String forms of the subject's rdf:type objects.
        """
        self.uri = uri
        self.labels = labels
        self.alt_labels = alt_labels
        self.description = description
        self.types = types
28
+
29
def _lit2text(lit: Any) -> Optional[str]:
    """Return the text of an RDF literal or string, or None for anything else.

    A ``Literal`` is converted with ``str()``; any other ``str`` instance is
    returned as-is; every other value yields ``None``.
    """
    if isinstance(lit, Literal):
        return str(lit)
    return lit if isinstance(lit, str) else None
33
+
34
@lru_cache(maxsize=1)
def _load_graph(path: str) -> Graph:
    """Parse the Turtle file at *path* into an rdflib graph.

    The result is memoised for the most recently requested path, so repeated
    calls with the same path do not re-parse the file.

    Raises:
        FileNotFoundError: If no file exists at *path*.
    """
    # Guard first so a missing file produces a clear, path-bearing message.
    if not os.path.exists(path):
        raise FileNotFoundError(f"Ontology file not found at: {path}")
    graph = Graph()
    graph.parse(path, format="turtle")
    return graph
41
+
42
@lru_cache(maxsize=1)
def _index_entries(path: str) -> List[OntologyEntry]:
    """Build one OntologyEntry per subject that carries label or description text.

    Labels come from rdfs:label and skos:prefLabel, alternative labels from
    skos:altLabel, and the description is the space-joined text of
    rdfs:comment, skos:definition and dcterms:description. Subjects with none
    of these are skipped. The result is memoised for the most recent *path*.
    """
    graph = _load_graph(path)

    def texts(subject, predicate):
        # Non-empty text values of every (subject, predicate, ?) triple.
        found = []
        for _, _, obj in graph.triples((subject, predicate, None)):
            text = _lit2text(obj)
            if text:
                found.append(text)
        return found

    indexed: Dict[str, OntologyEntry] = {}
    for subject in set(graph.subjects()):
        primary = set(texts(subject, RDFS.label))
        primary.update(texts(subject, SKOS.prefLabel))
        alternates = set(texts(subject, SKOS.altLabel))

        descriptions = []
        for predicate in (RDFS.comment, SKOS.definition, DCTERMS.description):
            descriptions.extend(texts(subject, predicate))

        if not (primary or alternates or descriptions):
            continue

        rdf_types = [
            str(obj)
            for _, _, obj in graph.triples((subject, RDF.type, None))
            if isinstance(obj, URIRef)
        ]
        key = str(subject)
        indexed[key] = OntologyEntry(
            key, sorted(primary), sorted(alternates), " ".join(descriptions), rdf_types
        )
    return list(indexed.values())
62
+
63
def _neighbors(g: Graph, node: URIRef) -> List[URIRef]:
    """Return the URI nodes linked to *node* in either direction.

    Only skos:broader, skos:narrower, skos:related and rdfs:seeAlso links are
    followed; non-URI nodes are ignored. Each neighbour appears once.
    """
    linked = set()
    for predicate in (SKOS.broader, SKOS.narrower, SKOS.related, RDFS.seeAlso):
        # Outgoing links: node --predicate--> other
        linked.update(
            obj for _, _, obj in g.triples((node, predicate, None)) if isinstance(obj, URIRef)
        )
        # Incoming links: other --predicate--> node
        linked.update(
            subj for subj, _, _ in g.triples((None, predicate, node)) if isinstance(subj, URIRef)
        )
    return list(linked)
71
+
72
def expand_concepts(path: str, seeds: List[OntologyEntry], depth: int = 1) -> List[OntologyEntry]:
    """Grow *seeds* with ontology neighbours reachable within *depth* hops.

    Args:
        path: Ontology file path, passed to the graph and index loaders.
        seeds: Entries to start from.
        depth: Number of breadth-first hops to follow.

    Returns:
        *seeds* itself when ``depth <= 0`` or *seeds* is empty; otherwise the
        indexed entries for the seeds and every reached neighbour. URIs with
        no indexed entry are dropped.
    """
    if not seeds or depth <= 0:
        return seeds

    graph = _load_graph(path)
    by_uri = {entry.uri: entry for entry in _index_entries(path)}

    current = [URIRef(seed.uri) for seed in seeds]
    seen = set(current)
    reached = {seed.uri for seed in seeds}

    hops_left = depth
    while hops_left > 0:
        upcoming = []
        for node in current:
            for neighbour in _neighbors(graph, node):
                if neighbour in seen:
                    continue
                seen.add(neighbour)
                upcoming.append(neighbour)
                reached.add(str(neighbour))
        current = upcoming
        hops_left -= 1

    return [by_uri[uri] for uri in reached if uri in by_uri]
86
+
87
def _normalise_text(s: str) -> str:
    """Lower-case *s* and collapse every run of whitespace to a single space."""
    # split() with no argument already discards leading/trailing whitespace.
    words = s.lower().split()
    return " ".join(words)
89
+
90
def keyword_scores(prompt: str, entries: List[OntologyEntry]) -> List[Tuple[OntologyEntry, float]]:
    """Score entries by token overlap between *prompt* and their texts.

    For each entry, every label, alternative label and the description is
    compared with the prompt. A text's score is the percentage of prompt
    tokens that also occur in the text; the entry keeps its best text score.

    Args:
        prompt: User text to match.
        entries: Candidate ontology entries.

    Returns:
        ``(entry, score)`` pairs with score > 0, sorted by descending score.
    """
    # Tokenise the prompt once; it is invariant across all entries and texts.
    prompt_tokens = _normalise_text(prompt).split()
    denominator = max(1, len(prompt_tokens))  # avoid dividing by zero on an empty prompt

    scored = []
    for entry in entries:
        candidates = entry.labels + entry.alt_labels
        if entry.description:
            candidates = candidates + [entry.description]

        best = 0.0
        for text in candidates:
            if not text:
                continue
            # A set gives O(1) membership instead of re-splitting per token.
            vocabulary = set(_normalise_text(text).split())
            overlap = sum(1 for token in prompt_tokens if token in vocabulary)
            best = max(best, 100.0 * overlap / denominator)

        if best > 0:
            scored.append((entry, best))

    scored.sort(key=lambda pair: pair[1], reverse=True)
    return scored
102
+
103
def _get_cohere_client():
    """Create a Cohere client from the COHERE_API_KEY environment variable.

    Raises:
        RuntimeError: If COHERE_API_KEY is unset or empty.
    """
    api_key = os.getenv("COHERE_API_KEY")
    if api_key:
        # Imported lazily so the module loads even when cohere is not needed.
        import cohere
        return cohere.Client(api_key)
    raise RuntimeError("COHERE_API_KEY not set")
107
+
108
def _normalize_cohere_chat_model(name: Optional[str]) -> str:
    """Map shorthand Cohere chat model names to their canonical identifiers.

    Args:
        name: Model name as typed by the user; may be None, empty or padded
            with whitespace.

    Returns:
        ``"command-r"`` for a missing/blank name or one of its aliases,
        ``"command-r-plus"`` for one of its aliases, otherwise the given name
        with surrounding whitespace removed (original casing preserved).
    """
    cleaned = (name or "").strip()
    if not cleaned:
        # Whitespace-only input is treated like a missing name.
        return "command-r"

    key = cleaned.lower()
    if key in {"r", "commandr", "command_r"}:
        return "command-r"
    if key in {"r+", "r-plus", "commandr+", "commandr-plus", "command_r_plus"}:
        return "command-r-plus"
    # Return the stripped name so stray whitespace never reaches the API.
    return cleaned
114
+
115
def embed_texts(texts: List[str], model: str, input_type: str = "search_document") -> np.ndarray:
    """Embed *texts* with a Cohere embedding model.

    Args:
        texts: Strings to embed.
        model: Cohere embedding model name.
        input_type: Cohere ``input_type`` value sent with the request. Cohere
            requires this parameter for v3 embedding models; without it the
            request is rejected.

    Returns:
        A float array built from the response's ``embeddings``.

    Raises:
        RuntimeError: If the API key is missing or the response carries no
            embeddings.
    """
    client = _get_cohere_client()
    res = client.embed(texts=texts, model=model, input_type=input_type)
    # Accept both attribute-style responses and plain dict responses.
    vecs = getattr(res, "embeddings", None) or (res.get("embeddings") if isinstance(res, dict) else None)
    if vecs is None:
        raise RuntimeError("Unexpected response from Cohere embed()")
    return np.array(vecs, dtype=float)
121
+
122
def cosine_sim(a: np.ndarray, b: np.ndarray) -> np.ndarray:
    """Return the matrix of cosine similarities between rows of *a* and *b*.

    Both inputs are converted to float copies, so the caller's arrays are not
    modified. Norms are taken along axis 1, and a small epsilon is added to
    each norm so all-zero rows do not divide by zero.
    """
    left = a.astype(float)
    right = b.astype(float)
    left_unit = left / (np.linalg.norm(left, axis=1, keepdims=True) + 1e-12)
    right_unit = right / (np.linalg.norm(right, axis=1, keepdims=True) + 1e-12)
    return left_unit @ right_unit.T
127
+
128
def embedding_scores(prompt: str, entries: List[OntologyEntry], embed_model: str) -> List[Tuple[OntologyEntry, float]]:
    """Rank *entries* by cosine similarity between their text and *prompt*.

    Each entry is represented by its labels and alternative labels joined
    with " | ", followed by the first 300 characters of its description when
    present; an entry with no text at all is represented by its URI.

    Returns:
        ``(entry, similarity)`` pairs sorted by descending similarity, or an
        empty list when *entries* is empty.
    """
    documents = []
    for entry in entries:
        text = " | ".join(entry.labels + entry.alt_labels)
        if entry.description:
            text += " | " + entry.description[:300]
        documents.append(text or entry.uri)

    if not documents:
        return []

    prompt_vec = embed_texts([prompt], embed_model)
    document_vecs = embed_texts(documents, embed_model)
    similarities = cosine_sim(prompt_vec, document_vecs)[0]

    # documents[i] was built from entries[i], so the two zip directly.
    ranked = [(entry, float(sim)) for entry, sim in zip(entries, similarities)]
    ranked.sort(key=lambda pair: pair[1], reverse=True)
    return ranked
140
+
141
def make_enhanced_prompt(original_prompt: str, matches: List[Tuple[OntologyEntry, float]], style: str = "balanced") -> str:
    """Wrap *original_prompt* with cues taken from matched ontology entries.

    Args:
        original_prompt: The user's prompt; surrounding whitespace is removed.
        matches: ``(entry, score)`` pairs, best first. Scores are not shown.
        style: ``"minimal"`` produces the prompt followed by a bullet list of
            up to 8 concept labels. Any other value (including "balanced" and
            "verbose") produces the full template with up to 10 cues and
            answering guidelines.

    Returns:
        The enhanced prompt as newline-separated text.
    """

    def label_of(entry) -> str:
        # Prefer the first label; otherwise use the last path segment of the URI.
        return entry.labels[0] if entry.labels else entry.uri.rsplit("/", 1)[-1]

    request = original_prompt.strip()

    if style == "minimal":
        lines = [request]
        if matches:
            # Real newlines: the escaped "\\n" form would print a literal backslash-n.
            lines.append("\nConsider these related concepts:")
            lines.extend(f"- {label_of(entry)}" for entry, _ in matches[:8])
        return "\n".join(lines)

    lines = [
        "You are to answer the user succinctly and accurately.",
        "First, consider these ontology cues to interpret the request more broadly; then answer plainly.\n",
        "User request:",
        f'"""{request}"""\n',
    ]
    if matches:
        lines.append("Ontology cues possibly relevant:")
        for entry, _ in matches[:10]:
            # Keep each cue on one line: truncate, trim, and flatten newlines.
            note = entry.description[:140].strip().replace("\n", " ") if entry.description else ""
            if note and not note.endswith("."):
                note += "."
            # Omit the dash separator when there is no description to show.
            lines.append(f"- {label_of(entry)} — {note}" if note else f"- {label_of(entry)}")
    else:
        lines.append("No strong ontology matches were found; proceed with general best practices.")
    lines += [
        "\nWhen responding, please:",
        "- Make assumptions explicit; surface trade-offs and impacts if relevant.",
        "- Use precise terms; avoid vague growth/technological-fix framings unless justified.",
        "- If uncertain, state limits and what evidence would resolve them.",
        "\nNow provide your answer:",
    ]
    return "\n".join(lines)
173
+
174
def _find_matches(user_prompt: str, method: str, top_k: int, expansion_depth: int) -> List[Tuple[OntologyEntry, float]]:
    """Find the ontology entries that best match *user_prompt*.

    Args:
        user_prompt: Text to match against the ontology index.
        method: ``"embedding"`` to use Cohere embeddings; anything else uses
            keyword overlap. Embedding failures fall back to keyword scoring.
        top_k: Maximum number of matches returned.
        expansion_depth: Neighbour hops used to widen the initial matches;
            0 disables expansion.

    Returns:
        Up to *top_k* ``(entry, score)`` pairs sorted by descending score.
        All returned scores come from a single scoring method so they are
        comparable with one another.
    """
    import logging

    log = logging.getLogger(__name__)

    # UI sliders may deliver numeric values as floats; slicing and range() need ints.
    top_k = int(top_k)
    expansion_depth = int(expansion_depth)

    entries = _index_entries(ONTOLOGY_PATH)

    use_embeddings = method == "embedding"
    base = None
    if use_embeddings:
        try:
            base = embedding_scores(user_prompt, entries, COHERE_EMBED_MODEL)[:top_k]
        except Exception as exc:
            # Deliberate best-effort: embedding search is optional, keyword search always works.
            log.warning("Embedding search failed, falling back to keyword search: %s", exc)
            use_embeddings = False
    if base is None:
        base = keyword_scores(user_prompt, entries)[:top_k]

    if expansion_depth <= 0 or not base:
        return base

    expanded = expand_concepts(ONTOLOGY_PATH, [entry for entry, _ in base], depth=expansion_depth)

    # Rescore neighbours with the SAME scorer as the base matches. Keyword
    # scores (0-100) and cosine similarities (<= 1) are not comparable, so
    # mixing them would let any keyword hit outrank every embedding match.
    if use_embeddings:
        try:
            rescored = embedding_scores(user_prompt, expanded, COHERE_EMBED_MODEL)
        except Exception as exc:
            log.warning("Embedding rescoring failed, returning unexpanded matches: %s", exc)
            return base
    else:
        rescored = keyword_scores(user_prompt, expanded)

    best_by_uri = {}
    for entry, score in base + rescored:
        current = best_by_uri.get(entry.uri)
        if current is None or score > current[1]:
            best_by_uri[entry.uri] = (entry, score)
    return sorted(best_by_uri.values(), key=lambda pair: pair[1], reverse=True)[:top_k]
189
+
190
def _llm_chat(prompt: str, model: Optional[str] = None, temperature: float = 0.2) -> str:
    """Send *prompt* to a Cohere chat model and return the reply text.

    Never raises: every failure is reported as a bracketed message string.
    If the chosen model is reported as "not found" and is not already
    "command-r", the call is retried once with "command-r".

    Args:
        prompt: Message sent to the model.
        model: Model name or alias; falls back to COHERE_CHAT_MODEL when empty.
        temperature: Sampling temperature forwarded to the API.
    """
    chosen = _normalize_cohere_chat_model(model or COHERE_CHAT_MODEL)

    try:
        client = _get_cohere_client()
    except Exception as e:
        return f"[LLM disabled: {e}]"

    def ask(model_name: str) -> str:
        res = client.chat(model=model_name, message=prompt, temperature=temperature)
        return getattr(res, "text", str(res))

    try:
        return ask(chosen)
    except Exception as first_error:
        model_missing = "not found" in str(first_error).lower()
        if not model_missing or chosen == "command-r":
            return f"[LLM error: {first_error}]"

    # Only reached when the chosen model was not found: retry with the default.
    try:
        return ask("command-r")
    except Exception as retry_error:
        return f"[LLM error after fallback: {retry_error}]"
207
+
208
def enhance_prompt_tool(user_prompt: str,
                        search_method: str = DEFAULT_SEARCH_METHOD,
                        top_k: int = TOP_K_CONCEPTS,
                        expansion_depth: int = EXPANSION_DEPTH,
                        style: str = DEFAULT_STYLE,
                        call_llm: bool = False,
                        temperature: float = 0.2,
                        chat_model: Optional[str] = None):
    """Enhance a prompt with ontology concepts and optionally query the LLM.

    Args:
        user_prompt: The prompt to enhance.
        search_method: Concept search method, "keyword" or "embedding".
        top_k: Maximum number of matched concepts to keep.
        expansion_depth: Neighbour hops used to widen the matches.
        style: Prompt template style passed to make_enhanced_prompt.
        call_llm: When true, also ask the chat model with both prompts.
        temperature: Sampling temperature for the chat calls.
        chat_model: Chat model name; the configured default is used if empty.

    Returns:
        A dict with "original_prompt", "enhanced_prompt" and "matches" (one
        dict per match with "uri", "label" and "score"), plus
        "original_reply" and "enhanced_reply" when *call_llm* is true.
    """
    matches = _find_matches(user_prompt, method=search_method, top_k=top_k, expansion_depth=expansion_depth)
    enhanced = make_enhanced_prompt(user_prompt, matches, style=style)

    match_rows = []
    for entry, score in matches:
        if entry.labels:
            label = entry.labels[0]
        else:
            # No label: fall back to the last path segment of the URI.
            label = entry.uri.rsplit("/", 1)[-1]
        match_rows.append({"uri": entry.uri, "label": label, "score": score})

    result = {
        "original_prompt": user_prompt,
        "enhanced_prompt": enhanced,
        "matches": match_rows,
    }
    if call_llm:
        # Same model and temperature for both calls so the replies are comparable.
        for key, text in (("original_reply", user_prompt), ("enhanced_reply", enhanced)):
            result[key] = _llm_chat(text, model=chat_model, temperature=temperature)
    return result
228
+
229
# Gradio UI: one form whose controls map one-to-one onto the parameters of
# enhance_prompt_tool, with the returned dict rendered as JSON.
with gr.Blocks(title="Ontology Prompt Enhancer (MCP)") as demo:
    gr.Markdown("# Ontology Prompt Enhancer (MCP)")

    # Prompt to enhance.
    with gr.Row():
        p = gr.Textbox(label="Your prompt")

    # Matching strategy and prompt template.
    with gr.Row():
        m = gr.Radio(choices=["keyword", "embedding"], value=DEFAULT_SEARCH_METHOD, label="Search method")
        st = gr.Radio(choices=["minimal", "balanced", "verbose"], value=DEFAULT_STYLE, label="Prompt style")

    # How many concepts to keep and how far to expand through the ontology.
    with gr.Row():
        k = gr.Slider(minimum=1, maximum=20, value=TOP_K_CONCEPTS, step=1, label="Top-K concepts")
        d = gr.Slider(minimum=0, maximum=3, value=EXPANSION_DEPTH, step=1, label="Expansion depth")

    # Optional LLM call and its settings.
    with gr.Row():
        call = gr.Checkbox(value=False, label="Also call LLM (Cohere)")
        temp = gr.Slider(minimum=0.0, maximum=1.0, value=0.2, step=0.05, label="Temperature")
        model = gr.Textbox(value=COHERE_CHAT_MODEL, label="Cohere chat model")

    out = gr.JSON(label="Result")

    # Input order must follow enhance_prompt_tool's positional parameters:
    # user_prompt, search_method, top_k, expansion_depth, style, call_llm,
    # temperature, chat_model.
    enhance_button = gr.Button("Enhance")
    enhance_button.click(
        fn=enhance_prompt_tool,
        inputs=[p, m, k, d, st, call, temp, model],
        outputs=out,
    )
245
+
246
if __name__ == "__main__":
    # Listen on all interfaces; the port comes from PORT (default 7860).
    port = int(os.getenv("PORT", "7860"))
    demo.launch(server_name="0.0.0.0", server_port=port, mcp_server=True)