import os
import pickle
from json import dumps, loads
from typing import Any, List, Mapping, Optional

import numpy as np
import openai
import pandas as pd
import streamlit as st
from dotenv import load_dotenv
from huggingface_hub import HfFileSystem
from langchain.llms.base import LLM
from llama_index import (
    Document,
    GPTVectorStoreIndex,
    LLMPredictor,
    PromptHelper,
    ServiceContext,
    SimpleDirectoryReader,
    StorageContext,
    load_index_from_storage,
)
from llama_index.llms import (
    CompletionResponse,
    CompletionResponseGen,
    CustomLLM,
    LLMMetadata,
)
from llama_index.prompts import Prompt
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline

from assets.prompts import custom_prompts
load_dotenv()
# openai.api_key = os.getenv("OPENAI_API_KEY")

fs = HfFileSystem()
# Prompt helper settings.
CONTEXT_WINDOW = 2048  # maximum input size
NUM_OUTPUT = 525  # number of output tokens
CHUNK_OVERLAP_RATIO = 0.2  # maximum chunk overlap ratio

text_qa_template = Prompt(custom_prompts.text_qa_template_str)
refine_template = Prompt(custom_prompts.refine_template_str)

def load_model(model_name: str):
    """Load a causal LM from the Hugging Face Hub and wrap it in a
    text-generation pipeline."""
    # e.g. model_name = "bigscience/bloom-560m"
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    # Load the model with the config shipped with the checkpoint.
    model = AutoModelForCausalLM.from_pretrained(model_name)
    pipe = pipeline(
        task="text-generation",
        model=model,
        tokenizer=tokenizer,
        # device=0,  # GPU device number
        # max_length=512,
        do_sample=True,
        top_p=0.95,
        top_k=50,
        temperature=0.7,
    )
    return pipe
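
# A quick smoke test of the raw pipeline (a sketch, assuming the small
# "bigscience/bloom-560m" checkpoint referenced above):
#
#   pipe = load_model("bigscience/bloom-560m")
#   out = pipe("The capital of France is", max_new_tokens=20)
#   print(out[0]["generated_text"])
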
class OurLLM(CustomLLM):
    # CustomLLM is a pydantic model, so plain attribute assignment in
    # __init__ is rejected; declare the fields on the class instead.
    model_name: str
    pipeline: Any

    def __init__(self, model_name: str, model_pipeline):
        # Set both fields in a single super().__init__ call; two separate
        # calls would each validate an incomplete set of fields.
        super().__init__(model_name=model_name, pipeline=model_pipeline)

    @property
    def metadata(self) -> LLMMetadata:
        """Get LLM metadata."""
        return LLMMetadata(
            context_window=CONTEXT_WINDOW,
            num_output=NUM_OUTPUT,
            model_name=self.model_name,
        )

    def complete(self, prompt: str, **kwargs: Any) -> CompletionResponse:
        prompt_length = len(prompt)
        response = self.pipeline(prompt, max_new_tokens=NUM_OUTPUT)[0]["generated_text"]
        # Only return the newly generated tokens, not the echoed prompt.
        text = response[prompt_length:]
        return CompletionResponse(text=text)

    def stream_complete(self, prompt: str, **kwargs: Any) -> CompletionResponseGen:
        raise NotImplementedError()
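
# Exercising OurLLM outside the index (a sketch; the checkpoint name is an
# assumption for a quick local test):
#
#   name = "bigscience/bloom-560m"
#   llm = OurLLM(model_name=name, model_pipeline=load_model(name))
#   print(llm.complete("Q: What is a vector index? A:").text)
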
class LlamaCustom:
    def __init__(self, model_name: str) -> None:
        self.vector_index = self.initialize_index(model_name=model_name)

    def initialize_index(self, model_name: str):
        index_name = model_name.split("/")[-1]
        file_path = f"./vectorStores/{index_name}"
        if os.path.exists(path=file_path):
            # Rebuild the storage context and load the persisted index.
            storage_context = StorageContext.from_defaults(persist_dir=file_path)
            index = load_index_from_storage(storage_context)
            # Hugging Face repo read access:
            # with fs.open(file_path, "r") as file:
            #     index = pickle.loads(file.readlines())
            return index
        else:
            prompt_helper = PromptHelper(
                context_window=CONTEXT_WINDOW,
                num_output=NUM_OUTPUT,
                chunk_overlap_ratio=CHUNK_OVERLAP_RATIO,
            )
            # Define the LLM.
            pipe = load_model(model_name=model_name)
            llm = OurLLM(model_name=model_name, model_pipeline=pipe)
            llm_predictor = LLMPredictor(llm=llm)
            service_context = ServiceContext.from_defaults(
                llm_predictor=llm_predictor, prompt_helper=prompt_helper
            )
            # documents = prepare_data(r"./assets/regItems.json")
            documents = SimpleDirectoryReader(input_dir="./assets/pdf").load_data()
            index = GPTVectorStoreIndex.from_documents(
                documents, service_context=service_context
            )
            # Local write access.
            index.storage_context.persist(file_path)
            # Hugging Face repo write access:
            # with fs.open(file_path, "w") as file:
            #     file.write(pickle.dumps(index))
            return index
    def get_response(self, query_str):
        print("query_str: ", query_str)
        # query_engine = self.vector_index.as_query_engine()
        query_engine = self.vector_index.as_query_engine(
            text_qa_template=text_qa_template, refine_template=refine_template
        )
        response = query_engine.query(query_str)
        print("metadata: ", response.metadata)
        return str(response)
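
if __name__ == "__main__":
    # Minimal end-to-end smoke test (a sketch: the checkpoint name is an
    # assumption, and it expects PDFs under ./assets/pdf as loaded above).
    llama = LlamaCustom(model_name="bigscience/bloom-560m")
    print(llama.get_response("What topics do the indexed documents cover?"))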