Spaces:
Sleeping
Sleeping
| from fastapi import FastAPI, HTTPException | |
| from fastapi.middleware.cors import CORSMiddleware | |
| from transformers import AutoModelForCausalLM, AutoTokenizer | |
| from transformers import GenerationConfig | |
| from time import perf_counter | |
| import json | |
| from typing import List, Dict | |
| import time | |
| import datetime | |
| import uvicorn | |
| import torch | |
# Model identifiers. REMOTE_PATH is what both from_pretrained() calls below
# receive (presumably a Hugging Face Hub repository id - confirm).
# LOCAL_PATH is not referenced anywhere in the visible part of this file.
REMOTE_PATH = "KN123/nl2csv4instructions-TinyLlama-v2.0"
LOCAL_PATH = "nl2csv4instructions-TinyLlama-v2.0"

# Pick the inference device: CUDA when torch reports it available, else CPU.
# (Assign device = 'cpu' here instead to force CPU inference.)
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
print("LLM using", device)

# Load the model and tokenizer once at import time; every request reuses
# these module-level objects.
print("🟢 Fetching the models ....")
model = AutoModelForCausalLM.from_pretrained(REMOTE_PATH, device_map=device)
tokenizer = AutoTokenizer.from_pretrained(REMOTE_PATH)
print("🚀 Ready! nl2csv4instructions at your service!")
| # def get_prompt(tables, question): | |
| # prompt = f"""Made. tables: {tables}. question: {question}""" | |
| # # print(prompt) | |
| # return prompt | |
| # def prepare_input(question: str, tables: Dict[str, List[str]]): | |
| # tables = [f"""{table_name}({",".join(tables[table_name])})""" for table_name in tables] | |
| # # print(tables) | |
| # tables = ", ".join(tables) | |
| # # print(tables) | |
| # prompt = get_prompt(tables, question) | |
| # # print(prompt) | |
| # input_ids = tokenizer(prompt, max_length=512, return_tensors="pt").input_ids | |
| # # print(input_ids) | |
| # return input_ids | |
| # def inference(question: str, tables: Dict[str, List[str]]) -> str: | |
| # input_data = prepare_input(question=question, tables=tables) | |
| # input_data = input_data.to(model.device) | |
| # outputs = model.generate(inputs=input_data, num_beams=10, top_k=10, max_length=512) | |
| # # print("Outputs", outputs) | |
| # result = tokenizer.decode(token_ids=outputs[0], skip_special_tokens=True) | |
| # return result | |
def parse(output):
    """Return the text that follows the first '<|assistant|>' marker.

    Args:
        output: Decoded model output, expected to contain the prompt
            followed by the assistant's reply.

    Returns:
        Everything after the first occurrence of '<|assistant|>' with
        leading newline characters removed. When the marker is absent a
        notice is printed and an empty string is returned.
    """
    # partition() splits on the first occurrence; the middle element is
    # empty exactly when the marker does not occur in the text.
    _, marker, reply = output.partition('<|assistant|>')
    if not marker:
        print("End tag '<|assistant|>' not found in output.")
        return ""
    # Only leading '\n' characters are dropped; other whitespace is kept.
    return reply.lstrip('\n')
def formatted_prompt(question)-> str:
    """Wrap *question* in the chat template used to prompt the model.

    The result is a user turn holding the question, terminated by '</s>',
    followed by an opening '<|assistant|>' tag for the model to continue.
    """
    template = "<|user|>\n{}</s>\n<|assistant|>"
    return template.format(question)
def generate_response(user_input):
    """Run the model on one instruction and return raw and parsed output.

    Args:
        user_input: Natural-language instruction. It is wrapped in the chat
            template by ``formatted_prompt`` before being tokenized.

    Returns:
        A dict with three keys:
            ``llm_output``: decoded text of the first generated sequence
                (special tokens skipped).
            ``time_taken``: seconds spent on tokenization, generation and
                decoding, rounded to two decimals.
            ``parsed_text``: ``parse(llm_output)``, i.e. the text after the
                '<|assistant|>' marker.
    """
    prompt = formatted_prompt(user_input)
    generation_config = GenerationConfig(
        penalty_alpha=0.6,
        do_sample=True,
        top_k=5,
        temperature=0.1,
        repetition_penalty=1.2,
        # Reuse the EOS token id as the padding id for generation.
        pad_token_id=tokenizer.eos_token_id,
        max_new_tokens=20,
    )

    start_time = perf_counter()
    # Tokenize exactly once and move the tensors to the inference device.
    # (An earlier extra tokenizer call whose result was immediately
    # overwritten has been removed; it was pure wasted work per request.)
    inputs = tokenizer(prompt, return_tensors="pt").to(device=device)
    outputs = model.generate(**inputs, generation_config=generation_config)
    llm_output = tokenizer.decode(outputs[0], skip_special_tokens=True)
    output_time = round(perf_counter() - start_time, 2)

    # Parsing happens after the timer stops, so it is not part of time_taken.
    return {
        'llm_output': llm_output,
        'time_taken': output_time,
        'parsed_text': parse(llm_output),
    }
# ASGI application object; uvicorn serves it from the __main__ guard.
app = FastAPI()

# CORS policy applied to every response.
# NOTE(review): wildcard origins are combined with allow_credentials=True;
# confirm this is intended, since browsers restrict credentialed requests
# against a wildcard origin.
_cors_options = {
    "allow_origins": ["*"],  # any origin
    "allow_credentials": True,
    # Explicit method list (PATCH and HEAD, for example, are not included).
    "allow_methods": ["GET", "POST", "PUT", "DELETE", "OPTIONS"],
    "allow_headers": ["*"],  # any request header
}
app.add_middleware(CORSMiddleware, **_cors_options)
def home():
    """Return a static status payload describing this API.

    The payload carries a health message, the API version string, the
    service role and a short human-readable description.

    NOTE(review): no route decorator is visible on this function in this
    view of the file; confirm how it is registered with ``app``.
    """
    status = {}
    status["message"] = "Hello there! Everything is working fine!"
    status["api-version"] = "2.0.0"
    status["role"] = "nl2csv4instructions"
    status["description"] = "This api can be used to convert natural language into instructions in the form of comma separate values, Ex: pick the items F-1222 and E-2222 or ask to reset all the settings."
    return status
def generate():
    """Run the model on a fixed sample instruction and return the result.

    Calls ``generate_response`` with the hard-coded instruction
    'set the color Blue to F-2244' and returns its dict unchanged.

    NOTE(review): another function named ``generate`` is defined further
    down in this file; unless something (e.g. a route decorator not visible
    here) captures this one first, the later definition replaces the name.
    """
    return generate_response(user_input='set the color Blue to F-2244')
def generate(request_body:str):
    """Run the model on the caller-supplied instruction.

    Args:
        request_body: Instruction text, forwarded verbatim to
            ``generate_response`` as ``user_input``.

    Returns:
        The dict produced by ``generate_response``. Both the incoming text
        and the resulting dict are printed to stdout.
    """
    print("Request Got: ", request_body)
    response = generate_response(user_input=request_body)
    print(response)
    return response
if __name__ == "__main__":
    # When executed as a script, serve the app with uvicorn on the loopback
    # interface, port 7860.
    uvicorn.run(app, port=7860, host="127.0.0.1")