|
|
import supabase |
|
|
import gradio as gr |
|
|
from typing import Union, Optional |
|
|
import os |
|
|
from datetime import datetime |
|
|
import pytz |
|
|
from supabase_memory import SupabaseChatMessageHistory |
|
|
|
|
|
from dotenv import load_dotenv |
|
|
# Populate the process environment (python-dotenv) before any variable is read.
load_dotenv()

# Module-wide Supabase client shared by every helper below. The URL and key
# come from the SUPABASE_URL / SUPABASE_KEY environment variables.
supabase_client = supabase.create_client(
    os.getenv("SUPABASE_URL"),
    os.getenv("SUPABASE_KEY"),
)
|
|
|
|
|
def _get_user_id(request: gr.Request) -> str:
    """Return the ``users.id`` of the user who issued *request*.

    The ``users`` table is filtered on ``name == request.username`` and the
    ``id`` of the first returned row is used.

    Raises:
        IndexError: If the query returns no rows.
    """
    users_table = supabase_client.table("users")
    query = users_table.select("id").eq("name", request.username)
    rows = query.execute().data
    return rows[0]["id"]
|
|
|
|
|
def _delete_empty_sessions() -> None:
    """Delete every session row whose ``updated_at`` column is NULL.

    The table name is read from the ``SESSIONS_TABLE_NAME`` environment
    variable (``KeyError`` if it is not set).
    """
    sessions_table = supabase_client.table(os.environ["SESSIONS_TABLE_NAME"])
    never_updated = sessions_table.delete().is_("updated_at", "null")
    never_updated.execute()
|
|
|
|
|
def _get_session_ids(user_id: str) -> list:
    """Return the session ids owned by *user_id*, most recently created first.

    Sessions whose ``updated_at`` is NULL are purged beforehand through
    ``_delete_empty_sessions``. When the user has no session left, a new one
    is created and returned as the only element of the list.
    """
    _delete_empty_sessions()

    sessions_table = supabase_client.table(os.environ["SESSIONS_TABLE_NAME"])
    rows = (
        sessions_table.select("id")
        .eq("user_id", user_id)
        .order("created_at", desc=True)
        .execute()
        .data
    )

    if len(rows) > 0:
        return [row["id"] for row in rows]

    # No session yet: create one so callers always get a non-empty list.
    return [_create_session(user_id)]
|
|
|
|
|
def _get_latest_message_id(session_id: str) -> tuple:
    """Return ``(id, score)`` of the newest error-free AI message of a session.

    The messages table (``MESSAGES_TABLE_NAME`` environment variable) is
    filtered on ``chat_id == session_id``, ``error_log IS NULL`` and
    ``message->>type == "ai"``, ordered by ``created_at`` descending and
    limited to one row.

    Args:
        session_id: Value matched against the ``chat_id`` column.

    Returns:
        A 2-tuple ``(message_id, score)``. Both elements are ``None`` when no
        row matches; ``score`` is whatever the ``score`` column holds for the
        matched row.
    """
    rows = (
        supabase_client
        .table(os.environ["MESSAGES_TABLE_NAME"])
        .select("id, score, message, error_log")
        .eq("chat_id", session_id)
        .is_("error_log", "null")
        .eq("message->>type", "ai")
        .order("created_at", desc=True)
        .limit(1)
        .execute()
        .data
    )

    if not rows:
        return None, None

    latest = rows[0]
    return latest["id"], latest["score"]
|
|
|
|
|
def _get_session_messages(session_id: str) -> list:
    """Return the stored messages of a session as ``(first, second)`` pairs.

    Messages are read through ``SupabaseChatMessageHistory`` and their
    ``content`` values are grouped two by two in storage order (presumably
    a human message followed by the AI reply -- verify against the writer).

    Args:
        session_id: Session whose history is loaded.

    Returns:
        A list of 2-tuples of message contents. If the history holds an odd
        number of messages, the last tuple is ``(content, None)`` instead of
        raising ``IndexError``.
    """
    memory = SupabaseChatMessageHistory(
        session_id=session_id,
        client=supabase_client,
        table_name=os.environ.get("MESSAGES_TABLE_NAME"),
        session_name="chat",
    )

    contents = [message.content for message in memory.messages]

    # An unpaired trailing message (e.g. a question whose reply was never
    # stored) is padded with None so the pairing below cannot run off the end.
    # NOTE(review): the tuple-style Gradio Chatbot is expected to treat None
    # as "no message" -- confirm for the Gradio version in use.
    if len(contents) % 2:
        contents.append(None)

    return list(zip(contents[::2], contents[1::2]))
|
|
|
|
|
def _get_users() -> list:
    """Return the ``name`` and ``password`` columns of every ``users`` row."""
    query = supabase_client.table("users").select("name, password")
    return query.execute().data
|
|
|
|
|
def _update_session(
    session_id: str,
    metadata: Optional[dict] = None
):
    """Stamp a session with the current UTC time and, optionally, new metadata.

    Args:
        session_id: ``id`` of the session row to update.
        metadata: When not ``None``, written to the ``metadata`` column;
            otherwise that column is left out of the update.
    """
    changes = {"updated_at": datetime.now(pytz.utc).isoformat()}
    if metadata is not None:
        changes = {**changes, "metadata": metadata}

    sessions_table = supabase_client.table(os.environ["SESSIONS_TABLE_NAME"])
    sessions_table.update(changes).eq("id", session_id).execute()
|
|
|
|
|
def _score_chosen(
    session_id: str,
    score: Optional[str]
):
    """Persist the chosen score and return the matching UI component updates.

    When *score* is not ``None`` it is converted to ``int`` and written to the
    ``score`` column of the session's latest AI message (as resolved by
    ``_get_latest_message_id``); the owning session is then touched through
    ``_update_session``.

    Returns:
        A 3-tuple ``(gr.Column, gr.Textbox, gr.Button)``: the column is
        visible and the textbox/button are interactive only when a score was
        given; otherwise the textbox shows a placeholder asking for a vote.
    """
    print("score chosen...", score)
    has_score = score is not None

    if has_score:
        latest_id, _ = _get_latest_message_id(session_id)

        messages_table = supabase_client.table(os.environ.get("MESSAGES_TABLE_NAME"))
        updated = (
            messages_table
            .update({"score": int(score)})
            .eq("id", latest_id)
            .execute()
        )

        # The updated row carries the chat it belongs to; refresh that session.
        _update_session(updated.data[0]["chat_id"])
        hint = None
    else:
        hint = "Dai un voto alla risposta precedente prima di continuare la conversazione o iniziarne una nuova"

    return (
        gr.Column(visible=has_score),
        gr.Textbox(interactive=has_score, placeholder=hint),
        gr.Button(interactive=has_score),
    )
|
|
|
|
|
def _comment_submitted(
    session_id: str,
    comment: str
):
    """Store *comment* on the session's latest AI message and touch the session.

    The target row is the one resolved by ``_get_latest_message_id``; after
    the update, ``_update_session`` is called with the row's ``chat_id``.
    """
    latest_id, _ = _get_latest_message_id(session_id)

    messages_table = supabase_client.table(os.environ.get("MESSAGES_TABLE_NAME"))
    updated = (
        messages_table
        .update({"comment": comment})
        .eq("id", latest_id)
        .execute()
    )

    _update_session(updated.data[0]["chat_id"])
|
|
|
|
|
def _clear_comments():
    """Return component updates that reset the feedback widgets.

    Returns:
        A 5-tuple: a hidden ``gr.Column``, ``None``, a hidden ``gr.Radio``,
        an interactive ``gr.Textbox`` with no value, and an interactive
        ``gr.Button``.
    """
    hidden_column = gr.Column(visible=False)
    hidden_radio = gr.Radio(visible=False)
    empty_textbox = gr.Textbox(interactive=True, value=None)
    enabled_button = gr.Button(interactive=True)

    return hidden_column, None, hidden_radio, empty_textbox, enabled_button
|
|
|
|
|
def _create_session(
    user_id: str,
) -> str:
    """Insert a new session row for *user_id* and return its ``id``.

    The table name is read from the ``SESSIONS_TABLE_NAME`` environment
    variable.
    """
    sessions_table = supabase_client.table(os.environ["SESSIONS_TABLE_NAME"])
    new_row = {"user_id": user_id}
    inserted = sessions_table.insert(new_row).execute()
    return inserted.data[0]["id"]
|
|
|
|
|
def _new_chat(user_id: str):
    """Create a new session for *user_id* and return the reset UI state.

    The user's existing session ids are fetched first, then a new session is
    created and placed at the head of the dropdown choices.

    Returns:
        A 6-tuple: an empty list, a visible ``gr.Textbox`` with no value, a
        hidden ``gr.Column``, ``None``, a hidden ``gr.Radio`` and an
        interactive ``gr.Dropdown`` with the new session selected.
    """
    previous_ids = _get_session_ids(user_id)
    new_id = _create_session(user_id)

    textbox = gr.Textbox(visible=True, value=None)
    column = gr.Column(visible=False)
    radio = gr.Radio(visible=False)
    dropdown = gr.Dropdown(
        choices=[new_id, *previous_ids],
        value=new_id,
        interactive=True,
    )

    return [], textbox, column, None, radio, dropdown
|
|
|
|
|
def _get_session_metadata(session_id: str) -> dict:
    """Return the ``metadata`` column of the session row with the given id.

    Raises:
        IndexError: If the query returns no rows.
    """
    sessions_table = supabase_client.table(os.environ["SESSIONS_TABLE_NAME"])
    rows = sessions_table.select("metadata").eq("id", session_id).execute().data
    return rows[0]["metadata"]
|
|
|
|
|
def _session_id_selected(session_id):
    """Return the UI state for a newly selected session.

    Input is allowed when the session has no qualifying AI message yet or
    when its latest AI message already has a score; otherwise the textbox is
    locked, shows a placeholder asking for a vote, and the radio is shown.

    Returns:
        A 5-tuple: the session's message pairs, a ``gr.Textbox``, a hidden
        ``gr.Column``, ``None`` and a ``gr.Radio``.
    """
    latest_id, latest_score = _get_latest_message_id(session_id)
    is_empty_session = latest_id is None
    voted = latest_score is not None
    allow_inputs = voted or is_empty_session

    print(f"session_id_selected..., allow_inputs: {allow_inputs}, voted: {voted}, is_empty_session: {is_empty_session}")

    if allow_inputs:
        hint = None
    else:
        hint = "Dai un voto alla risposta precedente prima di continuare la conversazione o iniziarne una nuova"

    history = _get_session_messages(session_id)
    textbox = gr.Textbox(interactive=allow_inputs, value=None, placeholder=hint)
    column = gr.Column(visible=False)
    radio = gr.Radio(visible=not allow_inputs)

    return history, textbox, column, None, radio
|
|
|
|
|
def _load_interface(request: gr.Request):
    """Resolve the requesting user and build the initial session dropdown.

    Returns:
        A 2-tuple ``(user_id, gr.Dropdown)`` where the dropdown lists the
        user's session ids and preselects the first one.
    """
    user_id = _get_user_id(request)
    available_ids = _get_session_ids(user_id)

    print("loading interface...")

    dropdown = gr.Dropdown(choices=available_ids, value=available_ids[0])
    return user_id, dropdown
|
|
|
|
|
def _get_link(doc: dict) -> Optional[tuple]:
    """Return ``(title, url)`` for a retrieved document.

    The behaviour depends on ``doc["type"]``:

    * ``"Prassi"``: the ``def_id`` of the ``praxis`` row whose ``id`` equals
      ``doc["supabase_praxis_id"]`` is used to build a ``getPrassiDetail.do``
      URL.
    * ``"Dottrina"``: no URL is available; the second element is ``None``.
    * ``"Norma"``: the ``articles`` row whose ``id`` equals
      ``doc["supabase_id"]`` (joined with its norm's ``def_id``) is used to
      build a ``getAttoNormativoDetail.do`` URL.

    Args:
        doc: Document dictionary with at least ``type`` and ``title`` for the
            three known types.

    Returns:
        ``(title, url)`` where ``url`` may be ``None``, or ``None`` when the
        document type is not one of the three above.

    Raises:
        IndexError: If the looked-up row does not exist.
    """
    # Local import: the module header is not guaranteed to import urllib.
    from urllib.parse import quote, urlencode

    MAIN_URL = "https://def.finanze.it/DocTribFrontend"

    doc_type = doc["type"]

    if doc_type == "Prassi":
        row = (
            supabase_client.table("praxis")
            .select("def_id")
            .eq("id", doc["supabase_praxis_id"])
            .limit(1)
            .execute()
            .data[0]
        )
        # %7B / %7D are the percent-encoded braces around the identifier.
        url = f"{MAIN_URL}/getPrassiDetail.do?id=%7B{row['def_id'].upper()}%7D"
        return doc["title"], url

    if doc_type == "Dottrina":
        return doc["title"], None

    if doc_type == "Norma":
        row = (
            supabase_client.table("articles")
            .select("def_id, name, metadata, norms(def_id)")
            .eq("id", doc["supabase_id"])
            .limit(1)
            .execute()
            .data[0]
        )

        params = {
            "ACTION": "getArticolo",
            "id": "{" + row["norms"]["def_id"].upper() + "}",
            "articolo": row["name"],
            "codiceOrdinamento": row["metadata"]["codOrdinamento"],
        }

        # quote_via=quote keeps spaces as %20 (not "+") and also escapes
        # braces, "&", "#", parentheses, etc., so unusual article names can
        # neither corrupt the query string nor break a markdown link.
        query = urlencode(params, quote_via=quote)
        return doc["title"], f"{MAIN_URL}/getAttoNormativoDetail.do?{query}"

    # Unknown document type: kept as None for backward compatibility.
    return None
|
|
|
|
|
|
|
|
def _add_footnote_description(answer: str, docs: list):
    """Append a markdown footnote definition for each ``[^uuid]`` marker.

    For every distinct footnote marker ``[^uuid]`` found in *answer*, the
    first document in *docs* whose ``supabase_id`` equals ``uuid`` is resolved
    through ``_get_link`` and a line ``[^uuid]: [title](link)`` (or
    ``[^uuid]: title`` when no link is available) is appended to the answer.

    Args:
        answer: Text that may contain ``[^uuid]`` footnote markers.
        docs: Document dictionaries, each with a ``supabase_id`` key.
            ``None`` or an empty list leaves the answer untouched.

    Returns:
        The answer followed by one definition line per resolved footnote, in
        order of first appearance. Markers with no matching document, or
        whose document type ``_get_link`` cannot resolve, are skipped.
    """
    import re

    if not docs:
        return answer

    # Only real footnote markers ("[^...]") are considered; dict.fromkeys
    # removes repeats while keeping first-appearance order, so a marker used
    # several times yields a single definition and a single lookup.
    uuids = dict.fromkeys(re.findall(r"\[\^([^\]]+)\]", answer))
    if not uuids:
        return answer

    # First document wins when several share the same id.
    docs_by_id = {}
    for doc in docs:
        docs_by_id.setdefault(doc["supabase_id"], doc)

    definitions = []
    for uuid in uuids:
        matching_doc = docs_by_id.get(uuid)
        if matching_doc is None:
            continue

        resolved = _get_link(matching_doc)
        if resolved is None:
            # _get_link returns None for document types it does not know.
            continue

        title, link = resolved
        marker = f"[^{uuid}]"
        if link is not None:
            definitions.append(f"\n{marker}: [{title}]({link})")
        else:
            definitions.append(f"\n{marker}: {title}")

    return answer + "".join(definitions)
|
|
|
|
|
def _replace_markdown_links(answer: str, docs: list):
    """Replace markdown link placeholders ``[text](uuid)`` with real links.

    Each ``[text](uuid)`` occurrence whose ``uuid`` equals the ``supabase_id``
    of a document in *docs* is rewritten using the URL returned by
    ``_get_link`` for that document.

    Args:
        answer: Text containing ``[text](uuid)`` placeholders.
        docs: Document dictionaries, each with a ``supabase_id`` key.
            ``None`` or an empty list leaves the answer untouched.

    Returns:
        The answer where every resolvable placeholder became
        ``[text](url)``, every placeholder whose document has no URL became
        plain ``text``, and every placeholder with no matching document was
        left unchanged.
    """
    import re

    if not docs:
        return answer

    link_pattern = re.compile(r"\[([^\]]+)\]\(([^)]+)\)")

    def replace_link(match):
        text, uuid = match.group(1), match.group(2)

        matching_doc = next(
            (doc for doc in docs if doc["supabase_id"] == uuid), None
        )
        if matching_doc is None:
            # Not one of our placeholders (e.g. a genuine link): keep as is.
            return match.group(0)

        # _get_link returns (title, url), or None for unknown document
        # types; only the URL belongs in the link target.
        resolved = _get_link(matching_doc)
        link = resolved[1] if resolved is not None else None

        if link is None:
            return text
        return f"[{text}]({link})"

    return link_pattern.sub(replace_link, answer)
|
|
|
|
|
def _create_users(n=10):
    """Insert *n* users with random names and passwords into ``users``.

    Each user gets a 6-character name and a 12-character password made of
    ASCII letters and digits, drawn from the ``secrets`` module because the
    values act as credentials.

    Args:
        n: Number of users to create; ``0`` or a negative value inserts
            nothing.
    """
    import secrets
    import string

    alphabet = string.ascii_letters + string.digits

    def random_string(length):
        return "".join(secrets.choice(alphabet) for _ in range(length))

    # NOTE(review): the password is inserted exactly as generated (plaintext)
    # -- confirm whether hashing happens elsewhere.
    for _ in range(n):
        supabase_client.table("users").insert({
            "name": random_string(6),
            "password": random_string(12),
        }).execute()
|
|
|
|
|
|