|
|
from io import StringIO |
|
|
import streamlit as st |
|
|
import os |
|
|
from loguru import logger |
|
|
from client import OpenAIService, SQLService, FacebookLLAMAService, GoogleT5Service |
|
|
|
|
|
# Page-level Streamlit settings: wide layout, sidebar open on first load,
# and the page title.
st.set_page_config(
    layout='wide',
    initial_sidebar_state='expanded',
    page_title=' A SQL Generative Pre-trained Transformer',
)
|
|
|
|
|
# SQL dialects offered in both the source and target selectors.
# Order matters: the selectors pick their defaults by index.
databases = [
    'Oracle',
    'SQLServer',
    'MySQL',
    'DB2',
    'PostgreSQL',
    'Snowflake',
    'Redshift',
]

# Translation backends; transform() matches these labels verbatim to choose
# which service client to build.
models = [
    'openai-text-david-003',
    'google-t5',
    'facebook-llama',
]
|
|
|
|
|
|
|
|
# Sidebar title.
st.sidebar.header('π A SQL Transformer for Migration')

# Backend used for translation; preselects the first entry of `models`.
model = st.sidebar.selectbox(
    label='Model',
    options=models,
    index=0,
)

# Key handed to OpenAIService; None when the "OPEN-KEY" variable is unset.
openai_key = os.getenv("OPEN-KEY")
|
|
|
|
|
# Dialect the input SQL is written in (index 0 -> 'Oracle').
source_database = st.sidebar.selectbox(
    label='π Source Database', options=databases, index=0,
)

# Dialect to translate into (index 4 -> 'PostgreSQL').
target_database = st.sidebar.selectbox(
    label='πTarget Database', options=databases, index=4,
)
|
|
|
|
|
# SQL typed directly by the user; transform() prefers this over a file.
input_text = st.sidebar.text_area(
    label='π Insert SQL',
    placeholder='select id from customer where rownum <= 100',
    height=200,
)

# Optional single SQL file, used by transform() when no SQL is typed.
input_file = st.sidebar.file_uploader(
    label=" π Choose a SQL file",
    accept_multiple_files=False,
)
|
|
|
|
|
|
|
|
def _create_client(name):
    """Return the SQLService implementation for the model label *name*.

    Raises:
        ValueError: if *name* is not one of the labels handled here.
    """
    if name == "openai-text-david-003":
        return OpenAIService(openai_key)
    if name == "google-t5":
        return GoogleT5Service()
    if name == "facebook-llama":
        return FacebookLLAMAService()
    # Previously an unknown label silently left the client as None and failed
    # later with an opaque AttributeError on `.translate`.
    raise ValueError(f"Unsupported model: {name!r}")


def _read_input_sql():
    """Return the SQL to translate, or None when the user supplied nothing.

    Typed text takes precedence over an uploaded file. Text that is empty or
    whitespace-only counts as "nothing typed", so an uploaded file is used
    instead of sending blank text to the model.

    Raises:
        UnicodeDecodeError: if the uploaded file is not valid UTF-8.
    """
    if input_text and input_text.strip():
        return input_text
    if input_file is None:
        return None
    # "utf-8-sig" decodes plain UTF-8 identically and additionally drops a
    # leading byte-order mark if the file has one.
    return input_file.getvalue().decode("utf-8-sig")


def transform():
    """Translate the user's SQL from the source to the target dialect and render it.

    Reads the module-level sidebar values (``model``, ``source_database``,
    ``target_database``, ``input_text``, ``input_file``), passes the selected
    SQL to the chosen backend's ``translate`` and shows the first returned
    candidate as a SQL code block. Does nothing when no SQL was provided.
    An undecodable upload is reported with ``st.error`` instead of crashing.
    """
    try:
        sql = _read_input_sql()
    except UnicodeDecodeError:
        st.error("The uploaded file is not valid UTF-8 text.")
        return

    if sql is None:
        return

    # Build the client only when there is something to translate: the script
    # body runs on every execution, and constructing a model service is not
    # needed just to render the empty form.
    client: SQLService = _create_client(model)
    logger.info(f"Using Model:{model}")

    solutions = client.translate(source_database, target_database, sql)
    st.code(solutions[0], language='sql')
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Entry point: called at module top level, so a translation is attempted each
# time this script is executed.
transform()
|
|
|