# KN123's picture
# Upload 3 files
# a711d1c verified
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers import GenerationConfig
from time import perf_counter
import json
from typing import List, Dict
import time
import datetime
import uvicorn
import torch
# Run on the GPU when torch reports one is available; otherwise use the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
# device = 'cpu'
print("LLM using", device)

# Identifier handed to `from_pretrained` below. LOCAL_PATH is not referenced
# anywhere in the visible code (presumably the name of a local copy -- confirm).
REMOTE_PATH = "KN123/nl2csv4instructions-TinyLlama-v2.0"
LOCAL_PATH = "nl2csv4instructions-TinyLlama-v2.0"

# The model and tokenizer are loaded once, at import time, and shared by every
# request handler defined further down.
print("🟢 Fetching the models ....")
model = AutoModelForCausalLM.from_pretrained(
    REMOTE_PATH,
    device_map=device,
)
tokenizer = AutoTokenizer.from_pretrained(REMOTE_PATH)
print("🚀 Ready! nl2csv4instructions at your service!")
# def get_prompt(tables, question):
# prompt = f"""Made. tables: {tables}. question: {question}"""
# # print(prompt)
# return prompt
# def prepare_input(question: str, tables: Dict[str, List[str]]):
# tables = [f"""{table_name}({",".join(tables[table_name])})""" for table_name in tables]
# # print(tables)
# tables = ", ".join(tables)
# # print(tables)
# prompt = get_prompt(tables, question)
# # print(prompt)
# input_ids = tokenizer(prompt, max_length=512, return_tensors="pt").input_ids
# # print(input_ids)
# return input_ids
# def inference(question: str, tables: Dict[str, List[str]]) -> str:
# input_data = prepare_input(question=question, tables=tables)
# input_data = input_data.to(model.device)
# outputs = model.generate(inputs=input_data, num_beams=10, top_k=10, max_length=512)
# # print("Outputs", outputs)
# result = tokenizer.decode(token_ids=outputs[0], skip_special_tokens=True)
# return result
def parse(output):
    """Return the text that follows the first ``<|assistant|>`` marker.

    Leading newline characters are stripped from the extracted text. When the
    marker does not occur in *output*, a notice is printed and an empty string
    is returned.
    """
    marker = '<|assistant|>'
    _, found, tail = output.partition(marker)
    if not found:
        # Marker missing: report it and hand back an empty result.
        print("End tag '<|assistant|>' not found in output.")
        return ""
    return tail.lstrip('\n')
def formatted_prompt(question)-> str:
    """Wrap *question* in the user/assistant chat markers used for prompting.

    The result is the user marker, a newline, the question followed by
    ``</s>``, a newline, and finally the assistant marker with nothing after
    it, so that generation continues from the assistant turn.
    """
    template = "<|user|>\n{}</s>\n<|assistant|>"
    return template.format(question)
def generate_response(user_input):
    """Run the language model on *user_input* and report the timed result.

    Args:
        user_input: Instruction text; it is wrapped by ``formatted_prompt``
            before being tokenized.

    Returns:
        A dict with three keys:
            ``llm_output``: the full decoded output sequence with special
                tokens skipped.
            ``time_taken``: seconds spent on tokenization, generation and
                decoding, rounded to two decimals.
            ``parsed_text``: ``parse(llm_output)``, i.e. the text after the
                ``<|assistant|>`` marker.
    """
    prompt = formatted_prompt(user_input)
    # NOTE(review): `penalty_alpha` is documented as a contrastive-search
    # setting; together with `do_sample=True` it may have no effect -- confirm
    # against the installed transformers version before changing it.
    generation_config = GenerationConfig(
        penalty_alpha=0.6,
        do_sample=True,
        top_k=5,
        temperature=0.1,
        repetition_penalty=1.2,
        pad_token_id=tokenizer.eos_token_id,  # pad with the EOS token id
        max_new_tokens=20,
    )

    start_time = perf_counter()
    # Tokenize exactly once and move the tensors to the device selected at
    # import time. (An earlier extra tokenizer call whose result was
    # overwritten before use has been removed.)
    inputs = tokenizer(prompt, return_tensors="pt").to(device=device)
    outputs = model.generate(**inputs, generation_config=generation_config)
    llm_output = tokenizer.decode(outputs[0], skip_special_tokens=True)
    output_time = round(perf_counter() - start_time, 2)

    return {
        'llm_output': llm_output,
        'time_taken': output_time,
        'parsed_text': parse(llm_output),
    }
app = FastAPI()

# CORS: any origin may call this API from a browser.
# Credentials are disabled because a wildcard origin must not be combined with
# credentialed requests, and none of the handlers visible in this file read
# cookies or authorization headers. If credentialed requests are ever needed,
# replace the "*" entries with explicit values.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # every origin
    allow_credentials=False,
    allow_methods=["GET", "POST", "PUT", "DELETE", "OPTIONS"],  # explicit list of permitted methods
    allow_headers=["*"],  # every request header
)
@app.get("/")
def home():
    """Return a static status payload describing this API."""
    info = {}
    info["message"] = "Hello there! Everything is working fine!"
    info["api-version"] = "2.0.0"
    info["role"] = "nl2csv4instructions"
    info["description"] = "This api can be used to convert natural language into instructions in the form of comma separate values, Ex: pick the items F-1222 and E-2222 or ask to reset all the settings."
    return info
@app.get("/test-generate")
def generate():
    """Run the model on a fixed sample instruction and return the result dict."""
    # NOTE: the POST handler further down in this file is also named
    # `generate`, so the module-level name ends up bound to that one.
    sample_instruction = 'set the color Blue to F-2244'
    return generate_response(user_input=sample_instruction)
@app.post("/generate")
def generate(request_body:str):
    """Run the model on the caller-supplied instruction and return the result.

    The incoming text and the resulting dict are both printed to stdout.

    NOTE(review): despite its name, a bare ``str`` parameter like
    ``request_body`` is normally read by FastAPI from the query string rather
    than from the request body -- confirm how clients call this route.
    """
    print("Request Got: ", request_body)
    result = generate_response(user_input=request_body)
    print(result)
    return result
# Script entry point: serve the app with uvicorn on the loopback interface.
if __name__ == "__main__":
    uvicorn.run(
        app,
        host="127.0.0.1",
        port=7860,
    )