import json
import os
from collections.abc import Generator
from queue import Empty, Queue
from threading import Thread

from dotenv import load_dotenv
from fastapi import Depends, FastAPI
from fastapi.middleware.cors import CORSMiddleware
from langchain.callbacks.base import BaseCallbackHandler
from langchain_community.graphs import Neo4jGraph
from pydantic import BaseModel
from sse_starlette.sse import EventSourceResponse

from chains import (
    configure_llm_only_chain,
    configure_qa_rag_chain,
    generate_ticket,
    load_embedding_model,
    load_llm,
)
from utils import BaseLogger, create_vector_index

| load_dotenv(".env") | |
| url = os.getenv("NEO4J_URI") | |
| username = os.getenv("NEO4J_USERNAME") | |
| password = os.getenv("NEO4J_PASSWORD") | |
| ollama_base_url = os.getenv("OLLAMA_BASE_URL") | |
| embedding_model_name = os.getenv("EMBEDDING_MODEL") | |
| llm_name = os.getenv("LLM") | |
| # Remapping for Langchain Neo4j integration | |
| os.environ["NEO4J_URL"] = url | |
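# Example .env for the variables read above (a sketch: every value below is
# an assumption and must match your own deployment):
#
#   NEO4J_URI=neo4j://localhost:7687
#   NEO4J_USERNAME=neo4j
#   NEO4J_PASSWORD=password
#   OLLAMA_BASE_URL=http://localhost:11434
#   EMBEDDING_MODEL=sentence_transformer
#   LLM=llama2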
embeddings, dimension = load_embedding_model(
    embedding_model_name,
    config={"ollama_base_url": ollama_base_url},
    logger=BaseLogger(),
)

# If Neo4j is local, you can go to http://localhost:7474/ to browse the database
neo4j_graph = Neo4jGraph(
    url=url, username=username, password=password, refresh_schema=False
)
create_vector_index(neo4j_graph)

llm = load_llm(
    llm_name, logger=BaseLogger(), config={"ollama_base_url": ollama_base_url}
)
llm_chain = configure_llm_only_chain(llm)
rag_chain = configure_qa_rag_chain(
    llm, embeddings, embeddings_store_url=url, username=username, password=password
)
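# Quick sanity-check sketch for the chains above (assumes the same callable
# input shape the endpoints below use; kept commented so nothing runs at
# import time):
#
#   result = llm_chain({"question": "What is a vector index?", "chat_history": []})
#   print(result["answer"])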
class QueueCallback(BaseCallbackHandler):
    """Callback handler for streaming LLM responses to a queue."""

    def __init__(self, q):
        self.q = q

    def on_llm_new_token(self, token: str, **kwargs) -> None:
        self.q.put(token)

    def on_llm_end(self, *args, **kwargs) -> None:
        return self.q.empty()
def stream(cb, q) -> Generator:
    """Run cb on a worker thread and yield its tokens from q as they arrive."""
    job_done = object()  # sentinel signalling that cb has finished

    def task():
        cb()
        q.put(job_done)

    t = Thread(target=task)
    t.start()

    content = ""
    # Get each new token from the queue and yield it from our generator
    while True:
        try:
            next_token = q.get(True, timeout=1)
            if next_token is job_done:
                break
            content += next_token
            yield next_token, content
        except Empty:
            continue
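# Minimal standalone sketch of the producer/consumer pattern above (the
# `emit` callback is hypothetical and stands in for a streaming LLM call):
#
#   q = Queue()
#   def emit():
#       for tok in ("Hello", " ", "world"):
#           q.put(tok)
#   for token, so_far in stream(emit, q):
#       print(token, end="")  # prints "Hello world" token by token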
app = FastAPI()

origins = ["*"]

app.add_middleware(
    CORSMiddleware,
    allow_origins=origins,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
@app.get("/")
async def root():
    return {"message": "Hello World"}
class Question(BaseModel):
    text: str
    rag: bool = False


class BaseTicket(BaseModel):
    text: str
@app.get("/query-stream")
def qstream(question: Question = Depends()):
    output_function = llm_chain
    if question.rag:
        output_function = rag_chain

    q = Queue()

    def cb():
        output_function(
            {"question": question.text, "chat_history": []},
            callbacks=[QueueCallback(q)],
        )

    def generate():
        yield json.dumps({"init": True, "model": llm_name})
        for token, _ in stream(cb, q):
            yield json.dumps({"token": token})

    return EventSourceResponse(generate(), media_type="text/event-stream")
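# Client-side sketch for consuming the stream (httpx, the host, and the port
# are assumptions; the "data: " framing comes from sse_starlette's SSE
# encoding of each yielded payload):
#
#   import httpx, json
#   params = {"text": "What is Neo4j?", "rag": False}
#   with httpx.stream("GET", "http://localhost:8504/query-stream", params=params) as r:
#       for line in r.iter_lines():
#           if line.startswith("data: "):
#               print(json.loads(line[6:]).get("token", ""), end="")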
@app.get("/query")
async def ask(question: Question = Depends()):
    output_function = llm_chain
    if question.rag:
        output_function = rag_chain
    result = output_function(
        {"question": question.text, "chat_history": []}, callbacks=[]
    )
    return {"result": result["answer"], "model": llm_name}
@app.get("/generate-ticket")
async def generate_ticket_api(question: BaseTicket = Depends()):
    new_title, new_question = generate_ticket(
        neo4j_graph=neo4j_graph,
        llm_chain=llm_chain,
        input_question=question.text,
    )
    return {"result": {"title": new_title, "text": new_question}, "model": llm_name}