File size: 4,415 Bytes
a711d1c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers import GenerationConfig
from time import perf_counter
import json

from typing import List, Dict
import time
import datetime
import uvicorn
import torch

# Select the compute device: the GPU when PyTorch reports CUDA, else the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
# device = 'cpu'
print("LLM using", device)

# Model identifier handed to `from_pretrained` (presumably a Hugging Face Hub
# repo id -- confirm) and a local directory name for the same model.
REMOTE_PATH = "KN123/nl2csv4instructions-TinyLlama-v2.0"
# NOTE: LOCAL_PATH is not referenced anywhere in the visible code.
LOCAL_PATH = "nl2csv4instructions-TinyLlama-v2.0"

# Load the model and its tokenizer once at import time so that every request
# handler reuses the same instances.
print("🟢 Fetching the models ....")
model = AutoModelForCausalLM.from_pretrained(
    REMOTE_PATH,
    device_map=device,
)
tokenizer = AutoTokenizer.from_pretrained(REMOTE_PATH)
print("🚀 Ready! nl2csv4instructions at your service!")


# def get_prompt(tables, question):
#     prompt = f"""Made. tables: {tables}. question: {question}"""
#     # print(prompt)
#     return prompt

# def prepare_input(question: str, tables: Dict[str, List[str]]):
#     tables = [f"""{table_name}({",".join(tables[table_name])})""" for table_name in tables]
#     # print(tables)
#     tables = ", ".join(tables)
#     # print(tables)
#     prompt = get_prompt(tables, question)
#     # print(prompt)
#     input_ids = tokenizer(prompt, max_length=512, return_tensors="pt").input_ids
#     # print(input_ids)
#     return input_ids

# def inference(question: str, tables: Dict[str, List[str]]) -> str:
#     input_data = prepare_input(question=question, tables=tables)
#     input_data = input_data.to(model.device)
#     outputs = model.generate(inputs=input_data, num_beams=10, top_k=10, max_length=512)
#     # print("Outputs", outputs)
#     result = tokenizer.decode(token_ids=outputs[0], skip_special_tokens=True)
#     return result

def parse(output):
    """Return the text that follows the first '<|assistant|>' tag in *output*.

    Leading newline characters are stripped from the extracted text. When the
    tag is not present, a diagnostic message is printed and an empty string is
    returned.
    """
    marker = '<|assistant|>'
    _, found, tail = output.partition(marker)
    if not found:
        # Guard clause: nothing to extract without the tag.
        print("End tag '<|assistant|>' not found in output.")
        return ""
    # Only newlines are stripped; other leading whitespace is kept.
    return tail.lstrip('\n')

def formatted_prompt(question) -> str:
    """Wrap *question* in the chat template used by this service.

    The result is a '<|user|>' turn containing the question, closed with
    '</s>', followed by an open '<|assistant|>' tag for the model to complete.
    """
    user_turn = f"<|user|>\n{question}</s>"
    return f"{user_turn}\n<|assistant|>"

def generate_response(user_input):
    """Run the language model on *user_input* and return raw and parsed output.

    The input is wrapped with ``formatted_prompt``, tokenized, moved to the
    module-level ``device`` and passed to ``model.generate``.

    Args:
        user_input: The natural-language instruction to convert.

    Returns:
        A dict with three keys:
            ``llm_output``: decoded text of the first returned sequence, with
                special tokens skipped.
            ``time_taken``: elapsed seconds, rounded to two decimals, covering
                tokenization, generation and decoding.
            ``parsed_text``: the result of ``parse(llm_output)``.
    """
    prompt = formatted_prompt(user_input)

    # NOTE(review): penalty_alpha is a contrastive-search setting while
    # do_sample=True selects sampling -- confirm that both are intended.
    generation_config = GenerationConfig(
        penalty_alpha=0.6,
        do_sample=True,
        top_k=5,
        temperature=0.1,
        repetition_penalty=1.2,
        pad_token_id=tokenizer.eos_token_id,
        max_new_tokens=20,
    )
    start_time = perf_counter()

    # Tokenize exactly once and place the tensors on the selected device.
    # (The prompt used to be tokenized a second time into a value that was
    # immediately overwritten.)
    inputs = tokenizer(prompt, return_tensors="pt").to(device=device)

    outputs = model.generate(**inputs, generation_config=generation_config)
    llm_output = tokenizer.decode(outputs[0], skip_special_tokens=True)

    # Stop the clock before parsing so the figure reflects model work only.
    output_time = round(perf_counter() - start_time, 2)

    return {
        'llm_output': llm_output,
        'time_taken': output_time,
        'parsed_text': parse(llm_output),
    }

# Create the FastAPI application and attach CORS handling to it.
# NOTE(review): a wildcard origin is combined with allow_credentials=True.
# FastAPI's CORS documentation says credentialed requests need explicitly
# listed origins -- confirm this combination is intended.
app = FastAPI()
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # Allows all origins
    allow_credentials=True,
    allow_methods=["GET", "POST", "PUT", "DELETE", "OPTIONS"],  # Allows only these five HTTP methods
    allow_headers=["*"],  # Allows all headers
)

@app.get("/")
def home():
    """Root endpoint: return a static status message and API description."""
    info = {}
    info["message"] = "Hello there! Everything is working fine!"
    info["api-version"] = "2.0.0"
    info["role"] = "nl2csv4instructions"
    info["description"] = "This api can be used to convert natural language into instructions in the form of comma separate values, Ex: pick the items F-1222 and E-2222 or ask to reset all the settings."
    return info

@app.get("/test-generate")
def generate_test():
    """Smoke-test endpoint: run the model on a fixed sample instruction.

    Named ``generate_test`` so that it is not redefined (shadowed) by the
    POST ``/generate`` handler, which is called ``generate``. The route path
    is unchanged.

    Returns:
        The dict produced by ``generate_response`` for the sample input.
    """
    return generate_response(user_input='set the color Blue to F-2244')

@app.post("/generate")
def generate(request_body: str):
    """Run the model on the caller-supplied text and return the result.

    NOTE(review): ``request_body`` is declared as a bare ``str``, which
    FastAPI reads from the query string rather than from the request body --
    confirm that clients send it that way.

    Returns:
        The dict produced by ``generate_response`` for ``request_body``.
    """
    print("Request Got: ", request_body)
    response = generate_response(user_input=request_body)
    # Echo the full result to stdout for debugging.
    print(response)
    return response



# When executed as a script, serve the app with uvicorn on 127.0.0.1:7860.
if __name__ == "__main__":
    uvicorn.run(app, host="127.0.0.1", port=7860)