qunfei commited on
Commit
589c127
Β·
1 Parent(s): daee5b3

adding all the changes

Browse files
Files changed (2) hide show
  1. app.py +17 -15
  2. client.py +64 -0
app.py CHANGED
@@ -1,14 +1,7 @@
1
  from io import StringIO
2
  import streamlit as st
3
  from loguru import logger
4
-
5
-
6
- class CoreClient:
7
-
8
- def generate(self, source: str, target: str, code: str):
9
- logger.info(f"translate {source} into {target}, with SQL {code}")
10
- return code
11
-
12
 
13
  st.set_page_config(
14
  page_title=' A SQL Generative Pre-trained Transformer',
@@ -19,7 +12,9 @@ st.set_page_config(
19
  databases = ['Oracle', 'SQLServer', 'MySQL', 'DB2', 'PostgreSQL', 'Snowflake', 'Redshift']
20
  # -------------
21
 
22
- st.sidebar.header('πŸŽƒ A SQL Transformer for Migration ')
 
 
23
 
24
  source_database = st.sidebar.selectbox(
25
  label='πŸ“• Source Database',
@@ -43,24 +38,31 @@ input_file = st.sidebar.file_uploader(
43
  label=" πŸ“„ Choose a SQL file",
44
  accept_multiple_files=False)
45
 
46
- client = CoreClient()
47
-
48
 
49
  def transform():
 
 
 
 
 
 
 
 
 
50
  code = input_text
51
  source = source_database
52
  target = target_database
53
  if code:
54
- solution = client.generate(source, target, code)
55
- st.code(solution, language='sql')
56
  else:
57
  if input_file is not None:
58
  # To convert to a string based IO:
59
  stringio = StringIO(input_file.getvalue().decode("utf-8"))
60
  # To read file as string:
61
  sql = stringio.read()
62
- solution = client.generate(source, target, sql)
63
- st.code(solution, language='sql')
64
 
65
 
66
  # ---------------------------------------
 
1
  from io import StringIO
2
  import streamlit as st
3
  from loguru import logger
4
+ from client import OpenAIService, SQLService, FacebookLLAMAService, GoogleT5Service
 
 
 
 
 
 
 
5
 
6
  st.set_page_config(
7
  page_title=' A SQL Generative Pre-trained Transformer',
 
12
  databases = ['Oracle', 'SQLServer', 'MySQL', 'DB2', 'PostgreSQL', 'Snowflake', 'Redshift']
13
  # -------------
14
 
15
+ st.sidebar.header('πŸŽƒ A SQL Transformer for Migration')
16
+
17
+ model = st.sidebar.radio('Model', ['openai-text-david-003', 'google-t5', 'facebook-llama'])
18
 
19
  source_database = st.sidebar.selectbox(
20
  label='πŸ“• Source Database',
 
38
  label=" πŸ“„ Choose a SQL file",
39
  accept_multiple_files=False)
40
 
 
 
41
 
42
  def transform():
43
+ client: SQLService = None
44
+ if model == "openai-text-david-003":
45
+ client = OpenAIService(st.secrets("OPEN-ORG"), st.secrets("OPEN-KEY"))
46
+ elif model == "google-t5":
47
+ client = GoogleT5Service()
48
+ elif model == "facebook-llama":
49
+ client = FacebookLLAMAService()
50
+ logger.info(f"Using Model:{model}")
51
+
52
  code = input_text
53
  source = source_database
54
  target = target_database
55
  if code:
56
+ solutions = client.translate(source, target, code)
57
+ st.code(solutions[0], language='sql')
58
  else:
59
  if input_file is not None:
60
  # To convert to a string based IO:
61
  stringio = StringIO(input_file.getvalue().decode("utf-8"))
62
  # To read file as string:
63
  sql = stringio.read()
64
+ solutions = client.translate(source, target, sql)
65
+ st.code(solutions[0], language='sql')
66
 
67
 
68
  # ---------------------------------------
client.py ADDED
@@ -0,0 +1,64 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from abc import ABC, abstractmethod
2
+ from typing import List
3
+ from loguru import logger
4
+ import openai
5
+
6
+
7
+ class SQLService(ABC):
8
+
9
+ @abstractmethod
10
+ def translate(self, source_db: str, target_db: str, sql: str) -> List:
11
+ ...
12
+
13
+
14
+ class OpenAIService(SQLService):
15
+
16
+ def __init__(self, organization, api_key) -> None:
17
+ super().__init__()
18
+ openai.organization = organization
19
+ openai.api_key = api_key
20
+
21
+ def translate(self, source_db: str, target_db: str, sql: str) -> List:
22
+ results = []
23
+ try:
24
+ response = openai.Completion.create(
25
+ model="text-davinci-003",
26
+ prompt=f"##### Translate this function from Oracle into Postgresql\n"
27
+ f"### {source_db}"
28
+ f""
29
+ f" {sql}"
30
+ f""
31
+ f"### {target_db}",
32
+ temperature=0,
33
+ max_tokens=2048,
34
+ top_p=1,
35
+ frequency_penalty=0,
36
+ presence_penalty=0,
37
+ stop=["###"]
38
+ )
39
+ for choice in response.choices:
40
+ logger.info(f"transform {source_db} to {target_db}, SQL:")
41
+ logger.debug(choice.text)
42
+ results.append(choice.text)
43
+ except Exception as ex:
44
+ logger.error(f"transform from {source_db} to {target_db}, failed \n {sql}")
45
+ logger.exception(ex)
46
+ return results
47
+
48
+
49
+ class GoogleT5Service(SQLService):
50
+
51
+ def treanslate(self, source_db: str, target_db: str, sql: str) -> List:
52
+ pass
53
+
54
+ def __init__(self):
55
+ ...
56
+
57
+
58
+ class FacebookLLAMAService(SQLService):
59
+
60
+ def treanslate(self, source_db: str, target_db: str, sql: str) -> List:
61
+ pass
62
+
63
+ def __init__(self):
64
+ ...