jaothan committed
Commit e412fa4 · verified · parent: e4abcfb

Update app.py

Files changed (1): app.py (+163 -163)
app.py CHANGED
@@ -1,163 +1,163 @@
-import os
-
-from langchain_community.graphs import Neo4jGraph
-from dotenv import load_dotenv
-from utils import (
-    create_vector_index,
-    BaseLogger,
-)
-from chains import (
-    load_embedding_model,
-    load_llm,
-    configure_llm_only_chain,
-    configure_qa_rag_chain,
-    generate_ticket,
-)
-from fastapi import FastAPI, Depends
-from pydantic import BaseModel
-from langchain.callbacks.base import BaseCallbackHandler
-from threading import Thread
-from queue import Queue, Empty
-from collections.abc import Generator
-from sse_starlette.sse import EventSourceResponse
-from fastapi.middleware.cors import CORSMiddleware
-import json
-
-load_dotenv(".env")
-
-url = os.getenv("NEO4J_URI")
-username = os.getenv("NEO4J_USERNAME")
-password = os.getenv("NEO4J_PASSWORD")
-ollama_base_url = os.getenv("OLLAMA_BASE_URL")
-embedding_model_name = os.getenv("EMBEDDING_MODEL")
-llm_name = os.getenv("LLM")
-# Remapping for Langchain Neo4j integration
-os.environ["NEO4J_URL"] = url
-
-embeddings, dimension = load_embedding_model(
-    embedding_model_name,
-    config={"ollama_base_url": ollama_base_url},
-    logger=BaseLogger(),
-)
-
-# if Neo4j is local, you can go to http://localhost:7474/ to browse the database
-neo4j_graph = Neo4jGraph(
-    url=url, username=username, password=password, refresh_schema=False
-)
-create_vector_index(neo4j_graph)
-
-llm = load_llm(
-    llm_name, logger=BaseLogger(), config={"ollama_base_url": ollama_base_url}
-)
-
-llm_chain = configure_llm_only_chain(llm)
-rag_chain = configure_qa_rag_chain(
-    llm, embeddings, embeddings_store_url=url, username=username, password=password
-)
-
-
-class QueueCallback(BaseCallbackHandler):
-    """Callback handler for streaming LLM responses to a queue."""
-
-    def __init__(self, q):
-        self.q = q
-
-    def on_llm_new_token(self, token: str, **kwargs) -> None:
-        self.q.put(token)
-
-    def on_llm_end(self, *args, **kwargs) -> None:
-        return self.q.empty()
-
-
-def stream(cb, q) -> Generator:
-    job_done = object()
-
-    def task():
-        x = cb()
-        q.put(job_done)
-
-    t = Thread(target=task)
-    t.start()
-
-    content = ""
-
-    # Get each new token from the queue and yield for our generator
-    while True:
-        try:
-            next_token = q.get(True, timeout=1)
-            if next_token is job_done:
-                break
-            content += next_token
-            yield next_token, content
-        except Empty:
-            continue
-
-
-app = FastAPI()
-origins = ["*"]
-
-app.add_middleware(
-    CORSMiddleware,
-    allow_origins=origins,
-    allow_credentials=True,
-    allow_methods=["*"],
-    allow_headers=["*"],
-)
-
-
-@app.get("/")
-async def root():
-    return {"message": "Hello World"}
-
-
-class Question(BaseModel):
-    text: str
-    rag: bool = False
-
-
-class BaseTicket(BaseModel):
-    text: str
-
-
-@app.get("/query-stream")
-def qstream(question: Question = Depends()):
-    output_function = llm_chain
-    if question.rag:
-        output_function = rag_chain
-
-    q = Queue()
-
-    def cb():
-        output_function(
-            {"question": question.text, "chat_history": []},
-            callbacks=[QueueCallback(q)],
-        )
-
-    def generate():
-        yield json.dumps({"init": True, "model": llm_name})
-        for token, _ in stream(cb, q):
-            yield json.dumps({"token": token})
-
-    return EventSourceResponse(generate(), media_type="text/event-stream")
-
-
-@app.get("/query")
-async def ask(question: Question = Depends()):
-    output_function = llm_chain
-    if question.rag:
-        output_function = rag_chain
-    result = output_function(
-        {"question": question.text, "chat_history": []}, callbacks=[]
-    )
-
-    return {"result": result["answer"], "model": llm_name}
-
-
-@app.get("/generate-ticket")
-async def generate_ticket_api(question: BaseTicket = Depends()):
-    new_title, new_question = generate_ticket(
-        neo4j_graph=neo4j_graph,
-        llm_chain=llm_chain,
-        input_question=question.text,
-    )
-    return {"result": {"title": new_title, "text": new_question}, "model": llm_name}
+import os
+
+from langchain_community.graphs import Neo4jGraph
+from dotenv import load_dotenv
+from utils import (
+    create_vector_index,
+    BaseLogger,
+)
+from chains import (
+    load_embedding_model,
+    load_llm,
+    configure_llm_only_chain,
+    configure_qa_rag_chain,
+    generate_ticket,
+)
+from fastapi import FastAPI, Depends
+from pydantic import BaseModel
+from langchain.callbacks.base import BaseCallbackHandler
+from threading import Thread
+from queue import Queue, Empty
+from collections.abc import Generator
+from sse_starlette.sse import EventSourceResponse
+from fastapi.middleware.cors import CORSMiddleware
+import json
+
+load_dotenv(".env")
+
+url = os.getenv("NEO4J_URI")
+username = os.getenv("NEO4J_USERNAME")
+password = os.getenv("NEO4J_PASSWORD")
+ollama_base_url = os.getenv("OLLAMA_BASE_URL")
+embedding_model_name = os.getenv("EMBEDDING_MODEL")
+llm_name = os.getenv("LLM")
+# Remapping for Langchain Neo4j integration
+os.environ["NEO4J_URL"] = "http:192.168.178.1:8000" #url
+
+embeddings, dimension = load_embedding_model(
+    embedding_model_name,
+    config={"ollama_base_url": ollama_base_url},
+    logger=BaseLogger(),
+)
+
+# if Neo4j is local, you can go to http://localhost:7474/ to browse the database
+neo4j_graph = Neo4jGraph(
+    url=url, username=username, password=password, refresh_schema=False
+)
+create_vector_index(neo4j_graph)
+
+llm = load_llm(
+    llm_name, logger=BaseLogger(), config={"ollama_base_url": ollama_base_url}
+)
+
+llm_chain = configure_llm_only_chain(llm)
+rag_chain = configure_qa_rag_chain(
+    llm, embeddings, embeddings_store_url=url, username=username, password=password
+)
+
+
+class QueueCallback(BaseCallbackHandler):
+    """Callback handler for streaming LLM responses to a queue."""
+
+    def __init__(self, q):
+        self.q = q
+
+    def on_llm_new_token(self, token: str, **kwargs) -> None:
+        self.q.put(token)
+
+    def on_llm_end(self, *args, **kwargs) -> None:
+        return self.q.empty()
+
+
+def stream(cb, q) -> Generator:
+    job_done = object()
+
+    def task():
+        x = cb()
+        q.put(job_done)
+
+    t = Thread(target=task)
+    t.start()
+
+    content = ""
+
+    # Get each new token from the queue and yield for our generator
+    while True:
+        try:
+            next_token = q.get(True, timeout=1)
+            if next_token is job_done:
+                break
+            content += next_token
+            yield next_token, content
+        except Empty:
+            continue
+
+
+app = FastAPI()
+origins = ["*"]
+
+app.add_middleware(
+    CORSMiddleware,
+    allow_origins=origins,
+    allow_credentials=True,
+    allow_methods=["*"],
+    allow_headers=["*"],
+)
+
+
+@app.get("/")
+async def root():
+    return {"message": "Hello World"}
+
+
+class Question(BaseModel):
+    text: str
+    rag: bool = False
+
+
+class BaseTicket(BaseModel):
+    text: str
+
+
+@app.get("/query-stream")
+def qstream(question: Question = Depends()):
+    output_function = llm_chain
+    if question.rag:
+        output_function = rag_chain
+
+    q = Queue()
+
+    def cb():
+        output_function(
+            {"question": question.text, "chat_history": []},
+            callbacks=[QueueCallback(q)],
+        )
+
+    def generate():
+        yield json.dumps({"init": True, "model": llm_name})
+        for token, _ in stream(cb, q):
+            yield json.dumps({"token": token})
+
+    return EventSourceResponse(generate(), media_type="text/event-stream")
+
+
+@app.get("/query")
+async def ask(question: Question = Depends()):
+    output_function = llm_chain
+    if question.rag:
+        output_function = rag_chain
+    result = output_function(
+        {"question": question.text, "chat_history": []}, callbacks=[]
+    )
+
+    return {"result": result["answer"], "model": llm_name}
+
+
+@app.get("/generate-ticket")
+async def generate_ticket_api(question: BaseTicket = Depends()):
+    new_title, new_question = generate_ticket(
+        neo4j_graph=neo4j_graph,
+        llm_chain=llm_chain,
+        input_question=question.text,
+    )
+    return {"result": {"title": new_title, "text": new_question}, "model": llm_name}