Spaces:
Build error
Build error
| from langchain_core.prompts import PromptTemplate | |
| from langchain_community.llms import CTransformers | |
| from langchain_community.embeddings import SentenceTransformerEmbeddings | |
| from langchain.chains import RetrievalQA | |
| from fastapi import FastAPI, Request, Form, Response | |
| from fastapi.responses import HTMLResponse | |
| from fastapi.templating import Jinja2Templates | |
| from fastapi.staticfiles import StaticFiles | |
| from fastapi.encoders import jsonable_encoder | |
| from qdrant_client import QdrantClient | |
| from langchain.vectorstores import Qdrant | |
| import os | |
| import json | |
# Jinja2 renderer for HTML pages kept in the local "templates" directory.
templates = Jinja2Templates(directory="templates")

# The ASGI application object. Files in the local "static" directory are
# served beneath the "/static" URL prefix under the route name "static".
app = FastAPI()
app.mount(path="/static", app=StaticFiles(directory="static"), name="static")
# Model identifier handed to CTransformers. The "owner/name" form suggests a
# Hugging Face Hub repo id, and the name indicates a Q4_K_M GGUF quantization.
local_llm = "joshnader/meditron-7b-Q4_K_M-GGUF"

# Generation/runtime settings for the ctransformers backend.
config = {
    'max_new_tokens': 512,
    'context_length': 2048,
    'repetition_penalty': 1.1,
    'temperature': 0.1,
    'top_k': 50,
    'top_p': 0.9,
    'stream': True,
    # Use a quarter of the visible CPUs, but never 0. os.cpu_count() may
    # return None, and integer-dividing fewer than 4 CPUs by 4 yields 0.
    'threads': max(1, (os.cpu_count() or 1) // 4),
}

# LangChain's CTransformers wrapper forwards backend settings only through
# its ``config`` field. It has no ``max_new_tokens``/``threads``/... fields
# of its own, so passing them as separate keyword arguments would not reach
# the model.
llm = CTransformers(
    model=local_llm,
    model_type="llama",
    config=config,
)
print("LLM Initialized....")
# Instruction template for the QA prompt. The {context} and {question}
# placeholders are filled in when the PromptTemplate built from it is used.
prompt_template = (
    "Use the following pieces of information to answer the user's question.\n"
    "If you don't know the answer, just say that you don't know, don't try to make up an answer.\n"
    "Context: {context}\n"
    "Question: {question}\n"
    "Only return the helpful answer below and nothing else.\n"
    "Helpful answer:\n"
)
# QA prompt built from prompt_template, declaring its two placeholders.
prompt = PromptTemplate(
    input_variables=['context', 'question'],
    template=prompt_template,
)

# Sentence-transformer embedding model used by the Qdrant vector store.
embeddings = SentenceTransformerEmbeddings(model_name="NeuML/pubmedbert-base-embeddings")

# Qdrant connection settings are read from the environment. The URL fallback
# appears to be a placeholder host, so QDRANT_URL presumably must be set.
client = QdrantClient(
    api_key=os.getenv("QDRANT_API_KEY"),
    url=os.getenv("QDRANT_URL", "https://QDRANT_URL.aws.cloud.qdrant.io"),
    prefer_grpc=False,
)

# LangChain vector-store wrapper over the "vector_db" collection.
db = Qdrant(collection_name="vector_db", client=client, embeddings=embeddings)

# Retriever limited to a single document per query (k=1).
retriever = db.as_retriever(search_kwargs=dict(k=1))
@app.get("/", response_class=HTMLResponse)
async def read_root(request: Request):
    """Serve the landing page.

    Args:
        request: The incoming request; Jinja2 template responses require it in
            the template context.

    Returns:
        The rendered ``index.html`` template.
    """
    return templates.TemplateResponse("index.html", {"request": request})
@app.post("/get_response")
async def get_response(query: str = Form(...)):
    """Answer ``query`` with the retrieval-augmented QA chain.

    Args:
        query: The user's question, read from the ``query`` form field.

    Returns:
        A ``Response`` whose body is a JSON object with the keys:

        - ``answer``: the chain's generated text.
        - ``source_document``: the page content of the first retrieved document.
        - ``doc``: that document's ``source`` metadata.

        ``source_document`` and ``doc`` are empty strings when nothing was
        retrieved, and ``doc`` is empty when the document has no ``source``
        metadata.
    """
    chain_type_kwargs = {"prompt": prompt}
    qa = RetrievalQA.from_chain_type(
        llm=llm,
        chain_type="stuff",
        retriever=retriever,
        return_source_documents=True,
        chain_type_kwargs=chain_type_kwargs,
        verbose=True,
    )
    # NOTE(review): invoke() is synchronous, so generation runs on the event
    # loop thread and blocks other requests until it completes.
    response = qa.invoke({"query": query})
    print(response)

    answer = response['result']
    source_documents = response.get('source_documents') or []
    if source_documents:
        top = source_documents[0]
        source_document = top.page_content
        doc = top.metadata.get('source', "")
    else:
        # Nothing retrieved (e.g. an empty collection): still return the
        # answer instead of failing with IndexError.
        source_document = ""
        doc = ""

    # The payload is already a JSON string, so send it as-is with a JSON
    # content type.
    payload = json.dumps({"answer": answer, "source_document": source_document, "doc": doc})
    return Response(content=payload, media_type="application/json")