jaothan committed on
Commit
e4abcfb
·
verified ·
1 Parent(s): b8dc493

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +163 -0
app.py ADDED
@@ -0,0 +1,163 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+
3
+ from langchain_community.graphs import Neo4jGraph
4
+ from dotenv import load_dotenv
5
+ from utils import (
6
+ create_vector_index,
7
+ BaseLogger,
8
+ )
9
+ from chains import (
10
+ load_embedding_model,
11
+ load_llm,
12
+ configure_llm_only_chain,
13
+ configure_qa_rag_chain,
14
+ generate_ticket,
15
+ )
16
+ from fastapi import FastAPI, Depends
17
+ from pydantic import BaseModel
18
+ from langchain.callbacks.base import BaseCallbackHandler
19
+ from threading import Thread
20
+ from queue import Queue, Empty
21
+ from collections.abc import Generator
22
+ from sse_starlette.sse import EventSourceResponse
23
+ from fastapi.middleware.cors import CORSMiddleware
24
+ import json
25
+
26
# Load environment configuration from the local .env file.
load_dotenv(".env")

# Connection/config values read from the environment; os.getenv returns None
# when a variable is unset — TODO confirm all are provided at deploy time
# (the os.environ assignment below requires `url` to be a str).
url = os.getenv("NEO4J_URI")
username = os.getenv("NEO4J_USERNAME")
password = os.getenv("NEO4J_PASSWORD")
ollama_base_url = os.getenv("OLLAMA_BASE_URL")
embedding_model_name = os.getenv("EMBEDDING_MODEL")
llm_name = os.getenv("LLM")
# Remapping for Langchain Neo4j integration
os.environ["NEO4J_URL"] = url

# Build the embedding model; `dimension` is the embedding vector size.
embeddings, dimension = load_embedding_model(
    embedding_model_name,
    config={"ollama_base_url": ollama_base_url},
    logger=BaseLogger(),
)

# if Neo4j is local, you can go to http://localhost:7474/ to browse the database
neo4j_graph = Neo4jGraph(
    url=url, username=username, password=password, refresh_schema=False
)
# Ensure the vector index used by the RAG chain exists.
create_vector_index(neo4j_graph)

llm = load_llm(
    llm_name, logger=BaseLogger(), config={"ollama_base_url": ollama_base_url}
)

# Two answering strategies: plain LLM-only chain, and a RAG chain backed by
# the Neo4j vector store. Endpoints pick between them via Question.rag.
llm_chain = configure_llm_only_chain(llm)
rag_chain = configure_qa_rag_chain(
    llm, embeddings, embeddings_store_url=url, username=username, password=password
)
57
+
58
+
59
class QueueCallback(BaseCallbackHandler):
    """Callback handler for streaming LLM responses to a queue.

    Each freshly generated token is pushed onto the supplied queue so a
    consumer (see ``stream`` below) can yield tokens as they arrive.
    """

    def __init__(self, q):
        # Queue shared with the consuming generator.
        self.q = q

    def on_llm_new_token(self, token: str, **kwargs) -> None:
        # Forward every new token to the consumer.
        self.q.put(token)

    def on_llm_end(self, *args, **kwargs) -> None:
        # Fix: previously returned ``self.q.empty()`` despite the ``-> None``
        # annotation; the value is never used by callers. End-of-stream is
        # signalled by the ``job_done`` sentinel in ``stream``, not here.
        return None
70
+
71
+
72
def stream(cb, q) -> Generator:
    """Run ``cb`` on a worker thread and yield tokens from ``q`` as they arrive.

    ``cb`` is expected to drive an LLM call whose callback handler puts each
    generated token on ``q`` (see ``QueueCallback``). Yields
    ``(token, accumulated_content)`` tuples until the end-of-stream sentinel
    is observed.
    """
    job_done = object()  # unique sentinel marking end of the stream

    def task():
        # Fix: enqueue the sentinel even if cb() raises; otherwise the
        # consumer loop below would poll forever waiting for it.
        try:
            cb()
        finally:
            q.put(job_done)

    t = Thread(target=task)
    t.start()

    content = ""

    # Get each new token from the queue and yield for our generator
    while True:
        try:
            next_token = q.get(True, timeout=1)
            if next_token is job_done:
                break
            content += next_token
            yield next_token, content
        except Empty:
            # Producer not finished yet; keep polling.
            continue
94
+
95
+
96
# FastAPI application with permissive CORS so any frontend origin can call it.
app = FastAPI()
# NOTE(review): "*" origins combined with allow_credentials=True is very
# permissive — confirm this is intended outside local development.
origins = ["*"]

app.add_middleware(
    CORSMiddleware,
    allow_origins=origins,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
106
+
107
+
108
@app.get("/")
async def root():
    """Trivial liveness endpoint confirming the API is up."""
    greeting = {"message": "Hello World"}
    return greeting
111
+
112
+
113
class Question(BaseModel):
    """Query parameters accepted by the /query and /query-stream endpoints."""

    # The user's question text.
    text: str
    # When True, answer via the RAG chain instead of the LLM-only chain.
    rag: bool = False
116
+
117
+
118
class BaseTicket(BaseModel):
    """Query parameters accepted by the /generate-ticket endpoint."""

    # Free-form text used as the seed question for the generated ticket.
    text: str
120
+
121
+
122
@app.get("/query-stream")
def qstream(question: Question = Depends()):
    """Answer a question as a server-sent-event token stream.

    Picks the RAG chain when ``question.rag`` is set, otherwise the plain
    LLM chain, and relays each generated token to the client as it arrives.
    """
    output_function = rag_chain if question.rag else llm_chain

    token_queue = Queue()

    def run_chain():
        # Drive the chain; QueueCallback pushes each token onto token_queue.
        output_function(
            {"question": question.text, "chat_history": []},
            callbacks=[QueueCallback(token_queue)],
        )

    def event_source():
        # First event announces the model in use, then one event per token.
        yield json.dumps({"init": True, "model": llm_name})
        for token, _ in stream(run_chain, token_queue):
            yield json.dumps({"token": token})

    return EventSourceResponse(event_source(), media_type="text/event-stream")
142
+
143
+
144
@app.get("/query")
async def ask(question: Question = Depends()):
    """Answer a question in a single (non-streaming) response.

    Uses the RAG chain when ``question.rag`` is set, otherwise the plain
    LLM chain, and returns the final answer plus the model name.
    """
    output_function = rag_chain if question.rag else llm_chain

    result = output_function(
        {"question": question.text, "chat_history": []}, callbacks=[]
    )
    return {"result": result["answer"], "model": llm_name}
154
+
155
+
156
@app.get("/generate-ticket")
async def generate_ticket_api(question: BaseTicket = Depends()):
    """Draft a new support-ticket title and body from the given text."""
    title, body = generate_ticket(
        neo4j_graph=neo4j_graph,
        llm_chain=llm_chain,
        input_question=question.text,
    )
    return {"result": {"title": title, "text": body}, "model": llm_name}