Spaces:
Runtime error
Runtime error
| import pickle | |
| import faiss | |
| from langchain import OpenAI | |
| from langchain.chains import VectorDBQAWithSourcesChain | |
| from zeno import ( | |
| ZenoOptions, | |
| distill, | |
| metric, | |
| model, | |
| ModelReturn, | |
| DistillReturn, | |
| MetricReturn, | |
| ) | |
@model
def get_model(model_name):
    """Build the Zeno prediction function for the Notion QA chatbot.

    Loads the FAISS index from ``./docs.index`` and the pickled vector store
    from ``./faiss_store.pkl``, attaches the index to the store, and wraps
    both in a ``VectorDBQAWithSourcesChain`` driven by ``OpenAI`` at
    temperature 0.

    The ``@model`` decorator (imported from ``zeno`` at the top of the file)
    is applied so that Zeno can recognise this function as a model loader.

    Args:
        model_name: Name of the model requested by Zeno. It is currently
            unused: the same chain is built regardless of its value.

    Returns:
        A ``pred(df, ops)`` callable that answers every question found in
        ``df[ops.data_column]`` and returns the formatted answers wrapped in
        a ``ModelReturn``.
    """
    # Blendle Notion chatbot example from:
    # https://github.com/hwchase17/chat-langchain-notion
    index = faiss.read_index("./docs.index")

    # SECURITY: unpickling executes arbitrary code. Only load a
    # ``faiss_store.pkl`` that was produced by a trusted source.
    with open("./faiss_store.pkl", "rb") as f:
        store = pickle.load(f)

    # The FAISS index is stored separately from the pickle, so it has to be
    # re-attached before the store can be queried.
    store.index = index

    chain = VectorDBQAWithSourcesChain.from_llm(
        llm=OpenAI(temperature=0), vectorstore=store
    )

    def _format(result):
        """Render one chain result as the text shown in Zeno."""
        return f"Answer: {result['answer']}\nSources: {result['sources']}"

    def pred(df, ops: ZenoOptions):
        """Run the QA chain on each question in ``df[ops.data_column]``."""
        answers = [
            _format(chain({"question": question}))
            for question in df[ops.data_column]
        ]
        return ModelReturn(model_output=answers)

    return pred
@distill
def correct(df, ops: ZenoOptions):
    """Flag rows whose model output contains the expected label.

    The comparison is a case-insensitive substring test: a row counts as
    correct when the lower-cased value of ``ops.label_column`` occurs inside
    the lower-cased value of ``ops.output_column``.

    The ``@distill`` decorator (imported from ``zeno`` at the top of the
    file) is applied so that Zeno can recognise this function as a distill
    function.

    Args:
        df: DataFrame holding the label and output columns.
        ops: Zeno options naming ``label_column`` and ``output_column``.

    Returns:
        A ``DistillReturn`` whose ``distill_output`` is a boolean Series
        aligned with ``df.index``.
    """
    import pandas as pd  # local import: the file has no top-level pandas import

    labels = df[ops.label_column]
    outputs = df[ops.output_column]

    # Zipping the two columns avoids ``df.apply(..., axis=1)``, which is slow
    # and returns an empty DataFrame (not a Series) when ``df`` has no rows.
    hits = [
        label.lower() in output.lower()
        for label, output in zip(labels, outputs)
    ]
    return DistillReturn(
        distill_output=pd.Series(hits, index=df.index, dtype=bool)
    )
@metric
def accuracy(df, ops: ZenoOptions):
    """Return the fraction of rows marked correct by the ``correct`` distill.

    The ``@metric`` decorator (imported from ``zeno`` at the top of the file)
    is applied so that Zeno can recognise this function as a metric.

    Args:
        df: DataFrame slice to score.
        ops: Zeno options; ``ops.distill_columns["correct"]`` names the
            column holding the per-row correctness flags.

    Returns:
        A ``MetricReturn`` whose ``metric`` is the mean of the correctness
        flags cast to int, or ``0.0`` when the slice has no rows.
    """
    # The mean of an empty column is NaN; report 0.0 for an empty slice.
    if len(df) == 0:
        return MetricReturn(metric=0.0)

    flags = df[ops.distill_columns["correct"]]
    return MetricReturn(metric=flags.astype(int).mean())