KN123 commited on
Commit
8a80df9
·
verified ·
1 Parent(s): fb80f11

Delete main.py

Browse files
Files changed (1) hide show
  1. main.py +0 -70
main.py DELETED
@@ -1,70 +0,0 @@
1
- from fastapi import FastAPI
2
- from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
3
- from typing import List, Dict
4
- import time
5
- import datetime
6
- import uvicorn
7
-
8
- model = AutoModelForSeq2SeqLM.from_pretrained("KN123/nl2sql")
9
- tokenizer = AutoTokenizer.from_pretrained("KN123/nl2sql")
10
-
11
- def get_prompt(tables, question):
12
- prompt = f"""convert question and table into SQL query. tables: {tables}. question: {question}"""
13
- # print(prompt)
14
- return prompt
15
-
16
- def prepare_input(question: str, tables: Dict[str, List[str]]):
17
- tables = [f"""{table_name}({",".join(tables[table_name])})""" for table_name in tables]
18
- # print(tables)
19
- tables = ", ".join(tables)
20
- # print(tables)
21
- prompt = get_prompt(tables, question)
22
- # print(prompt)
23
- input_ids = tokenizer(prompt, max_length=512, return_tensors="pt").input_ids
24
- # print(input_ids)
25
- return input_ids
26
-
27
- def inference(question: str, tables: Dict[str, List[str]]) -> str:
28
- input_data = prepare_input(question=question, tables=tables)
29
- input_data = input_data.to(model.device)
30
- outputs = model.generate(inputs=input_data, num_beams=10, top_k=10, max_length=512)
31
- # print("Outputs", outputs)
32
- result = tokenizer.decode(token_ids=outputs[0], skip_special_tokens=True)
33
- return result
34
-
35
- app = FastAPI()
36
-
37
- @app.get("/")
38
- def home():
39
- return {
40
- "message" : "Hello there! Everything is working fine!",
41
- "api-version": "1.0.0",
42
- "role": "nl2sql",
43
- "description": "This api can be used to convert natural language to SQL given the human prompt, tables and the attributes."
44
- }
45
-
46
- @app.get("/generate")
47
- def generate(text:str):
48
- start = time.time()
49
- res = inference("how many people with name jui and age less than 25", {
50
- "people_name":["id","name"], "people_age": ["people_id","age"]
51
- })
52
- end = time.time()
53
- total_time_taken = end - start
54
- current_utc_datetime = datetime.datetime.now(datetime.timezone.utc)
55
- current_date = datetime.date.today()
56
- timezone_name = time.tzname[time.daylight]
57
- print(res)
58
- return {
59
- "api_response": f"{res}",
60
- "time_taken(s)": f"{total_time_taken}",
61
- "request_details": {
62
- "utc_datetime": f"{current_utc_datetime}",
63
- "current_date": f"{current_date}",
64
- "timezone_name": f"{timezone_name}"
65
- }
66
- }
67
-
68
-
69
- if __name__ == "__main__":
70
- uvicorn.run(app, host="127.0.0.1", port=8000)