# main.py
import logging
from fastapi import FastAPI, Request
from fastapi.middleware.cors import CORSMiddleware
import nest_asyncio
from pyngrok import ngrok
import uvicorn
import json
from model import Model
from doc_reader import DocReader
from transformers import GenerationConfig, pipeline
from langchain_community.llms.huggingface_pipeline import HuggingFacePipeline
from langchain.prompts import PromptTemplate
from langchain.chains import RetrievalQA
from langchain.schema.runnable import RunnableBranch
from langchain_core.runnables import RunnableLambda
import torch
# Logger configuration
logging.basicConfig(level=logging.INFO,
                    format='%(asctime)s [%(levelname)s] %(message)s',
                    datefmt='%Y-%m-%d %H:%M:%S')
logger = logging.getLogger(__name__)

import os
os.system("nvidia-smi")
print("TORCH_CUDA", torch.cuda.is_available())

# Add path to sys
# sys.path.insert(0, '/opt/accelerate')
# sys.path.insert(0, '/opt/uvicorn')
# sys.path.insert(0, '/opt/pyngrok')
# sys.path.insert(0, '/opt/huggingface_hub')
# sys.path.insert(0, '/opt/nest_asyncio')
# sys.path.insert(0, '/opt/transformers')
# sys.path.insert(0, '/opt/pytorch')
# Initialize FastAPI app
app = FastAPI()

#NGROK_TOKEN = "2aQUM6MDkhjcPEBbIFTiu4cZBBr_sMMei8h5yejFbxFeMFuQ"  # Replace with your NGROK token
#MODEL_NAME = "/opt/Llama-2-13B-chat-GPTQ"
#MODEL_NAME = "MediaTek-Research/Breeze-7B-Instruct-64k-v0.1"
MODEL_NAME = "codellama/CodeLlama-7b-Instruct-hf"
PDF_PATH = "/opt/docs"
CLASSIFIER_MODEL_NAME = "roberta-large-mnli"

# Add CORS middleware
app.add_middleware(
    CORSMiddleware,
    allow_origins=['*'],
    allow_credentials=True,
    allow_methods=['*'],
    allow_headers=['*'],
)
model_instance = Model(MODEL_NAME)
model_instance.load()
#model_instance.load(model_name_or_path=GGUF_HUGGINGFACE_REPO, model_basename=GGUF_HUGGINGFACE_BIN_FILE)

# classifier_model = pipeline("zero-shot-classification",
#                             model=CLASSIFIER_MODEL_NAME)
@app.post("/predict")  # route path taken from the "/predict" hint in the ngrok block at the bottom of this file
async def predict_text(request: Request):
    try:
        # Parse request body as JSON
        request_body = await request.json()
        prompt = request_body.get("prompt", "")
        # TODO: handle additional parameters like 'temperature' or 'max_tokens' if needed
        result = general_chain.invoke({"question": prompt})
        logger.info(f"Result: {result}")
        formatted_response = {
            "choices": [
                {
                    "message": {
                        "content": result['result']
                    }
                }
            ]
        }
        return formatted_response
    except json.JSONDecodeError:
        return {"error": "Invalid JSON format"}
def load_pdfs():
    global db
    doc_reader = DocReader(PDF_PATH)
    # Load PDFs and convert to Markdown
    pages = doc_reader.load_pdfs()
    markdown_text = doc_reader.convert_to_markdown(pages)
    texts = doc_reader.split_text([markdown_text])  # Assuming split_text now takes a list of Markdown texts
    # Generate embeddings
    db = doc_reader.generate_embeddings(texts)

# def classify_sequence(input_data):
#     sequence_to_classify = input_data["question"]
#     candidate_labels = ['LinuxCommand', 'TechnicalSupport', 'GeneralResponse']
#     classification = classifier_model(sequence_to_classify, candidate_labels)
#     # Extract the label with the highest score
#     return {"topic": classification['labels'][0], "question": sequence_to_classify}

def format_output(output):
    return {"result": output}
def setup_chain():
    #global full_chain
    #global classifier_chain
    global command_chain
    #global support_chain
    global general_chain

    generation_config = GenerationConfig.from_pretrained(MODEL_NAME)
    generation_config.max_new_tokens = 1024
    generation_config.temperature = 0.3
    generation_config.top_p = 0.9
    generation_config.do_sample = True
    generation_config.repetition_penalty = 1.15

    text_pipeline = pipeline(
        "text-generation",
        model=model_instance.model,
        tokenizer=model_instance.tokenizer,
        return_full_text=True,
        generation_config=generation_config,
    )
    llm = HuggingFacePipeline(pipeline=text_pipeline)

    # Classifier
    #classifier_runnable = RunnableLambda(classify_sequence)

    # Formatter
    output_runnable = RunnableLambda(format_output)

    # System Commands
    command_template = """
    [INST] <<SYS>>
    As a Gemini Central engineer specializing in Linux, evaluate the user's input and choose the most likely command they want to execute from these options:
    - 'systemctl stop sbox-admin'
    - 'systemctl start sbox-admin'
    - 'systemctl restart sbox-admin'
    Respond with the chosen command. If uncertain, reply with 'No command will be executed'.
    <</SYS>>
    question:
    {question}
    answer:
    [/INST]"""
    command_chain = (PromptTemplate(template=command_template, input_variables=["question"]) | llm | output_runnable)

    # Support
    # support_template = """
    # [INST] <<SYS>>
    # Act as a Gemini support engineer who is good at reading technical data. Use the following information to answer the question at the end.
    # <</SYS>>
    # {context}
    # {question}
    # answer:
    # [/INST]
    # """

    # General
    general_template = """
    [INST] <<SYS>>
    You are an advanced AI assistant designed to provide assistance with a wide range of queries.
    Users may request you to assume various roles or perform diverse tasks
    <</SYS>>
    question:
    {question}
    answer:
    [/INST]"""
    general_chain = (PromptTemplate(template=general_template, input_variables=["question"]) | llm | output_runnable)
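    # Note on the composition above: PromptTemplate fills {question}, the HuggingFacePipeline
    # LLM generates the completion, and format_output wraps the raw string as {"result": ...},
    # which is what predict_text later reads back as result['result'].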
    #support_prompt = PromptTemplate(template=support_template, input_variables=["context", "question"])
    #support_chain = RetrievalQA.from_llm(llm=llm, retriever=db.as_retriever(), prompt=support_prompt, input_key="question", return_source_documents=True, verbose=True)
    # support_chain = RetrievalQA.from_chain_type(
    #     llm=llm,
    #     chain_type="stuff",
    #     #retriever=db.as_retriever(search_kwargs={"k": 3}),
    #     retriever=db.as_retriever(),
    #     input_key="question",
    #     return_source_documents=True,
    #     chain_type_kwargs={"prompt": support_prompt},
    #     verbose=False
    # )
    # logger.info("support chain loaded successfully.")

    # branch = RunnableBranch(
    #     (lambda x: x == "command", command_chain),
    #     (lambda x: x == "support", support_chain),
    #     general_chain,  # Default chain
    # )

    # def route_classification(output):
    #     if output['topic'] == 'LinuxCommand':
    #         logger.info("Routing to command chain")
    #         return command_chain
    #     elif output['topic'] == 'TechnicalSupport':
    #         logger.info("Routing to support chain")
    #         return support_chain
    #     else:
    #         logger.info("Routing to general chain")
    #         return general_chain
    # routing_runnable = RunnableLambda(route_classification)

    # Full chain integration
    #full_chain = classifier_runnable | routing_runnable
    #logger.info("Full chain loaded successfully.")

    return general_chain
###############
# launch once at startup
#load_pdfs()
setup_chain()
###############

#if __name__ == "__main__":
#    if NGROK_TOKEN is not None:
#        ngrok.set_auth_token(NGROK_TOKEN)
#        ngrok_tunnel = ngrok.connect(8000)
#        public_url = ngrok_tunnel.public_url
#        print('Public URL:', public_url)
#        print("You can use {}/predict to get the assistant result.".format(public_url))
#        logger.info("You can use {}/predict to get the assistant result.".format(public_url))
#    nest_asyncio.apply()
#    uvicorn.run(app, port=8000)
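# Minimal local-run sketch (assumes this file is saved as main.py, the model and /opt/docs
# paths above exist on the host, and the app is served on port 8000 as in the commented-out
# uvicorn call):
#
#   uvicorn main:app --port 8000
#
#   curl -X POST http://localhost:8000/predict \
#        -H "Content-Type: application/json" \
#        -d '{"prompt": "How do I restart sbox-admin?"}'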