wang16888 committed on
Commit
d39cedf
·
verified ·
1 Parent(s): 34ef624

Upload 9 files

Browse files
Files changed (9) hide show
  1. callbacks.py +24 -0
  2. chains.py +54 -0
  3. crud.py +27 -0
  4. data_indexing.py +154 -0
  5. database.py +12 -0
  6. main.py +118 -0
  7. models.py +23 -0
  8. prompts.py +70 -0
  9. schemas.py +26 -0
callbacks.py ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ from typing import Dict, Any, List
3
+ from langchain_core.callbacks import BaseCallbackHandler
4
+ import schemas
5
+ import crud
6
+
7
+
8
class LogResponseCallback(BaseCallbackHandler):
    """Callback handler that stores LLM completions as 'AI' chat messages."""

    def __init__(self, user_request: schemas.UserRequest, db):
        """Remember the request being answered and the database handle.

        Args:
            user_request: Request whose ``username`` the stored answer is
                attributed to.
            db: Database handle forwarded to ``crud.add_message``.
        """
        super().__init__()
        self.user_request = user_request
        self.db = db

    def on_llm_end(self, outputs: Any, **kwargs: Any) -> Any:
        """Persist the first generation of a finished LLM call.

        Args:
            outputs: Result object of the LLM run. It is read by attribute
                (``outputs.generations[0][0].text``), so it is not a plain
                dict.
            **kwargs: Ignored.

        NOTE(review): this fires for every LLM call made under the callback,
        so a chain that calls the model more than once will store each
        completion as an 'AI' message — confirm that is intended.
        """
        generations = outputs.generations
        # Nothing was generated: skip instead of failing with IndexError.
        if not generations or not generations[0]:
            return
        message = schemas.MessageBase(message=generations[0][0].text, type='AI')
        crud.add_message(self.db, message, self.user_request.username)

    def on_llm_start(
        self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
    ) -> Any:
        """Print every prompt that is about to be sent to the LLM."""
        for prompt in prompts:
            print(prompt)
chains.py ADDED
@@ -0,0 +1,54 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from langchain_huggingface import HuggingFaceEndpoint
3
+ from langchain_core.runnables import RunnablePassthrough
4
+ import schemas
5
+ from prompts import (
6
+ raw_prompt,
7
+ raw_prompt_formatted,
8
+ history_prompt_formatted,
9
+ standalone_prompt_formatted,
10
+ rag_prompt_formatted,
11
+ format_context,
12
+ tokenizer
13
+ )
14
+ from data_indexing import DataIndexer
15
+
16
+
17
# Retriever behind the RAG chain. It is instantiated at import time, which
# creates the embedding and Pinecone clients.
data_indexer = DataIndexer()


# Hosted Llama 3 endpoint shared by every chain; HF_TOKEN must be set.
llm = HuggingFaceEndpoint(
    repo_id="meta-llama/Meta-Llama-3-8B-Instruct",
    huggingfacehub_api_token=os.environ['HF_TOKEN'],
    max_new_tokens=512,
    stop_sequences=[tokenizer.eos_token],
    streaming=True,
)

# Sends the bare question to the model.
simple_chain = (raw_prompt | llm).with_types(input_type=schemas.UserQuestion)

# Same question, wrapped in the model's chat template.
formatted_chain = (
    raw_prompt_formatted
    | llm
).with_types(input_type=schemas.UserQuestion)

# Answers a follow-up question given the formatted chat history.
history_chain = (
    history_prompt_formatted
    | llm
).with_types(input_type=schemas.HistoryInput)

# Retrieval-augmented chain:
#   1. rewrite the follow-up into a standalone question (first LLM call),
#   2. retrieve context for that question,
#   3. answer from the retrieved context (second LLM call).
# The former debug entry that printed the whole input has been removed; the
# RAG prompt only consumes 'context' and 'standalone_question'.
rag_chain = (
    RunnablePassthrough.assign(new_question=standalone_prompt_formatted | llm)
    | {
        'context': lambda x: format_context(
            data_indexer.search(x['new_question'], hybrid_search=x['hybrid_search'])
        ),
        'standalone_question': lambda x: x['new_question'],
    }
    | rag_prompt_formatted
    | llm
).with_types(input_type=schemas.RagInput)
50
+
51
+
52
+
53
+
54
+
crud.py ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from sqlalchemy.orm import Session
2
+ import models, schemas
3
+
4
+
5
def get_or_create_user(db: Session, username: str):
    """Return the ``User`` row for *username*, inserting it first if missing."""
    existing = (
        db.query(models.User)
        .filter(models.User.username == username)
        .first()
    )
    if existing:
        return existing

    # No such user yet: create, commit and reload the new row.
    created = models.User(username=username)
    db.add(created)
    db.commit()
    db.refresh(created)
    return created
13
+
14
def add_message(db: Session, message: schemas.MessageBase, username: str):
    """Store *message* for *username* and return the persisted ``Message`` row.

    The user is created on the fly when it does not exist yet.
    """
    owner = get_or_create_user(db, username)

    # Build the ORM row from the schema's fields, then attach it to its owner.
    row = models.Message(**message.dict())
    row.user = owner

    db.add(row)
    db.commit()
    db.refresh(row)
    return row
22
+
23
def get_user_chat_history(db: Session, username: str):
    """Return the messages of *username*, or an empty list for unknown users."""
    owner = (
        db.query(models.User)
        .filter(models.User.username == username)
        .first()
    )
    return owner.messages if owner else []
data_indexing.py ADDED
@@ -0,0 +1,154 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import uuid
3
+ from pathlib import Path
4
+ from pinecone.grpc import PineconeGRPC as Pinecone
5
+ from pinecone import ServerlessSpec
6
+ from langchain_community.vectorstores import Chroma
7
+ from langchain_openai import OpenAIEmbeddings
8
+
9
import warnings

# Directory containing this module; sources.txt is stored next to it.
current_dir = Path(__file__).resolve().parent


# SECURITY: API keys must never be committed to source control. They are read
# from the process environment (deployment secrets, a local .env loader, ...)
# instead of being assigned here. Any key that was ever committed to this file
# has to be treated as leaked and revoked at the provider.
REQUIRED_ENV_VARS = ('PINECONE_API_KEY', 'OPENAI_API_KEY')

_missing_env_vars = [name for name in REQUIRED_ENV_VARS if not os.environ.get(name)]
if _missing_env_vars:
    # Only warn here so the module stays importable; the clients that need
    # the keys will fail when they are actually constructed.
    warnings.warn(
        'Missing required environment variable(s): ' + ', '.join(_missing_env_vars)
    )
14
+
15
+
16
class DataIndexer:
    """Embed documents with OpenAI embeddings and store/query them in Pinecone."""

    # Text file listing the 'source' metadata of indexed documents, one per line.
    source_file = os.path.join(current_dir, 'sources.txt')

    def __init__(self, index_name='langchain-repo') -> None:
        """Create the clients and make sure the Pinecone index exists.

        Args:
            index_name: Name of the Pinecone index to use; it is created as a
                serverless index when it is not listed yet.
        """
        self.embedding_client = OpenAIEmbeddings()

        self.index_name = index_name
        self.pinecone_client = Pinecone(api_key=os.environ.get('PINECONE_API_KEY'))

        if index_name not in self.pinecone_client.list_indexes().names():
            self.pinecone_client.create_index(
                name=index_name,
                # Must equal the embedding vector length — presumably 1536 for
                # the default OpenAIEmbeddings model; TODO confirm.
                dimension=1536,
                metric='cosine',
                spec=ServerlessSpec(
                    cloud='aws',
                    region='us-east-1'
                )
            )

        self.index = self.pinecone_client.Index(self.index_name)
        # The source index is deliberately not built here. While it is None,
        # search() ignores hybrid_search and runs an unfiltered query.
        # Assign self.get_source_index() to enable source filtering.
        self.source_index = None

    def get_source_index(self):
        """Build a Chroma vector store over the recorded source names.

        Returns:
            The vector store, or ``None`` when ``source_file`` does not exist.
        """
        if not os.path.isfile(self.source_file):
            print('No source file')
            return None

        print('create source index')

        with open(self.source_file, 'r', encoding='utf-8') as file:
            stripped = (line.rstrip('\n') for line in file)
            # index_data() appends one line per document, so a source can be
            # listed many times; keep the first occurrence of each and skip
            # blank lines.
            sources = list(dict.fromkeys(line for line in stripped if line))

        vectorstore = Chroma.from_texts(
            sources, embedding=self.embedding_client
        )
        return vectorstore

    def index_data(self, docs, batch_size=32):
        """Embed *docs* and upsert them into the Pinecone index in batches.

        Args:
            docs: Documents exposing ``page_content`` and a ``metadata`` dict
                that contains a ``'source'`` entry.
            batch_size: Number of documents embedded and upserted per request.

        A failed upsert is printed and the remaining batches are still
        processed (best effort).
        """
        # Record where each document came from, one source per line.
        with open(self.source_file, 'a', encoding='utf-8') as file:
            for doc in docs:
                file.write(doc.metadata['source'] + '\n')

        for start in range(0, len(docs), batch_size):
            batch = docs[start: start + batch_size]
            values = self.embedding_client.embed_documents([
                doc.page_content for doc in batch
            ])

            vector_ids = [str(uuid.uuid4()) for _ in batch]

            # The chunk text travels in the metadata so search() can return it.
            metadatas = [{
                'text': doc.page_content,
                **doc.metadata
            } for doc in batch]

            vectors = [{
                'id': vector_id,
                'values': value,
                'metadata': metadata
            } for vector_id, value, metadata in zip(vector_ids, values, metadatas)]

            try:
                upsert_response = self.index.upsert(vectors=vectors)
                print(upsert_response)
            except Exception as e:
                print(e)

    def search(self, text_query, top_k=5, hybrid_search=False):
        """Return the stored texts of the best matches for *text_query*.

        Args:
            text_query: Query text to embed.
            top_k: Maximum number of matches requested from the index.
            hybrid_search: When true and a source index is available, restrict
                the query to the sources most similar to the query text.

        Returns:
            List of the ``'text'`` metadata values of the matches; matches
            without that entry are skipped.
        """
        print('text query:', text_query)

        metadata_filter = None
        if hybrid_search and self.source_index:
            source_docs = self.source_index.similarity_search(text_query, 50)
            print("source_docs", source_docs)
            metadata_filter = {
                "source": {"$in": [doc.page_content for doc in source_docs]}
            }

        vector = self.embedding_client.embed_query(text_query)
        result = self.index.query(
            vector=vector,
            top_k=top_k,
            include_metadata=True,
            filter=metadata_filter
        )

        # Read the text without mutating the metadata of the query result.
        return [
            match["metadata"]['text']
            for match in result["matches"]
            if 'text' in match["metadata"]
        ]
120
+
121
+
122
if __name__ == '__main__':
    # One-off ingestion script: clone the LangChain repository, split its
    # Python and Markdown files into chunks and index them.

    from langchain_community.document_loaders import GitLoader
    from langchain_text_splitters import (
        Language,
        RecursiveCharacterTextSplitter,
    )

    loader = GitLoader(
        clone_url="https://github.com/langchain-ai/langchain",
        repo_path="./code_data/langchain_repo/",
        branch="master",
    )

    python_splitter = RecursiveCharacterTextSplitter.from_language(
        language=Language.PYTHON, chunk_size=10000, chunk_overlap=100
    )

    docs = loader.load()
    # Keep only reasonably sized Python and Markdown files.
    docs = [
        doc for doc in docs
        if doc.metadata['file_type'] in ('.py', '.md')
        and len(doc.page_content) < 50000
    ]
    docs = python_splitter.split_documents(docs)
    # Prefix every chunk with its file path so the source is part of the text.
    for doc in docs:
        doc.page_content = '# {}\n\n'.format(doc.metadata['source']) + doc.page_content

    indexer = DataIndexer()
    # index_data() already appends every source to DataIndexer.source_file, so
    # the extra write to a hard-coded '/app/sources.txt' was dropped: it either
    # duplicated each line or wrote a file nothing reads.
    indexer.index_data(docs)
    print('DONE')
153
+
154
+
database.py ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from sqlalchemy import create_engine
2
+ from sqlalchemy.ext.declarative import declarative_base
3
+ from sqlalchemy.orm import sessionmaker
4
+
5
# SQLite database file, relative to the working directory.
SQLALCHEMY_DATABASE_URL = "sqlite:///./test.db"

# check_same_thread=False is forwarded to the SQLite driver — presumably so a
# connection may be used from more than one thread; verify against the
# sqlite3 documentation.
engine = create_engine(
    SQLALCHEMY_DATABASE_URL,
    connect_args=dict(check_same_thread=False),
)

# Factory for sessions bound to the engine above.
SessionLocal = sessionmaker(bind=engine, autoflush=False, autocommit=False)

# Declarative base class that the ORM models inherit from.
Base = declarative_base()
main.py ADDED
@@ -0,0 +1,118 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from langchain_core.runnables import Runnable
2
+ from langchain_core.callbacks import BaseCallbackHandler
3
+ from fastapi import FastAPI, Request, Depends
4
+ from sse_starlette.sse import EventSourceResponse
5
+ from sqlalchemy.orm import Session
6
+ from langserve.serialization import WellKnownLCSerializer
7
+ from typing import Any, List
8
+ import crud, models, schemas
9
+ from database import SessionLocal, engine
10
+ from chains import simple_chain, formatted_chain, history_chain, rag_chain
11
+ from prompts import format_chat_history
12
+ from callbacks import LogResponseCallback
13
+
14
+
15
# Create any tables that do not exist yet.
models.Base.metadata.create_all(engine)

# ASGI application object that the route decorators below attach to.
app = FastAPI()
18
+
19
def get_db():
    """Yield a fresh database session and always close it afterwards."""
    session = SessionLocal()
    try:
        yield session
    finally:
        session.close()
25
+
26
+
27
async def generate_stream(
    input_data: schemas.BaseModel,
    runnable: Runnable,
    callbacks: List[BaseCallbackHandler] = None,
):
    """Stream the output of *runnable* as server-sent-event dictionaries.

    Args:
        input_data: Validated input model; its ``dict()`` form is the chain input.
        runnable: Chain to execute.
        callbacks: Optional callback handlers passed to the run. ``None``
            (the default) means no callbacks; a shared mutable default list
            is deliberately avoided.

    Yields:
        One ``{'data': <serialized chunk>, 'event': 'data'}`` dict per output
        chunk, followed by a final ``{'event': 'end'}``.

    NOTE(review): ``runnable.stream`` is iterated synchronously inside this
    async generator, so the event loop is blocked while waiting for chunks.
    Consider ``astream`` once the callback handlers are known to be safe
    under it.
    """
    handlers = list(callbacks) if callbacks else []
    # One serializer is enough for the whole stream.
    serializer = WellKnownLCSerializer()
    for output in runnable.stream(input_data.dict(), config={"callbacks": handlers}):
        data = serializer.dumps(output).decode("utf-8")
        yield {'data': data, "event": "data"}
    yield {"event": "end"}
32
+
33
+
34
@app.get("/")
def greet_json():
    """Return a fixed greeting payload."""
    greeting = {"Hello": "World!"}
    return greeting
37
+
38
+
39
@app.post("/simple/stream")
async def simple_stream(request: Request):
    """Stream the answer of ``simple_chain`` for the question under ``input``."""
    payload = await request.json()
    question = schemas.UserQuestion(**payload['input'])
    events = generate_stream(question, simple_chain)
    return EventSourceResponse(events)
44
+
45
+
46
@app.post("/formatted/stream")
async def formatted_stream(request: Request):
    """Stream the answer of ``formatted_chain`` for the question under ``input``."""
    payload = await request.json()
    question = schemas.UserQuestion(**payload['input'])
    events = generate_stream(question, formatted_chain)
    return EventSourceResponse(events)
51
+
52
+
53
@app.post("/history/stream")
async def history_stream(request: Request, db: Session = Depends(get_db)):
    """Answer a question with the user's chat history and stream the result.

    The question is stored as a 'User' message, and the model's answer is
    stored through ``LogResponseCallback``.
    """
    payload = await request.json()
    user_request = schemas.UserRequest(**payload['input'])
    username = user_request.username

    # The history is fetched before the new question is stored.
    # NOTE(review): confirm whether the message added below also appears in
    # this collection by the time it is formatted.
    chat_history = crud.get_user_chat_history(db=db, username=username)
    user_message = schemas.MessageBase(message=user_request.question, type='User')
    crud.add_message(db, user_message, username)

    history_input = schemas.HistoryInput(
        question=user_request.question,
        chat_history=format_chat_history(chat_history),
    )

    response_logger = LogResponseCallback(user_request, db)
    events = generate_stream(history_input, history_chain, [response_logger])
    return EventSourceResponse(events)
71
+
72
+
73
@app.post("/rag/stream")
async def rag_stream(request: Request, db: Session = Depends(get_db)):
    """Answer a question with ``rag_chain`` and stream the result.

    The question is stored as a 'User' message, and LLM output is stored
    through ``LogResponseCallback``. ``hybrid_search`` keeps its default.
    """
    payload = await request.json()
    user_request = schemas.UserRequest(**payload['input'])
    username = user_request.username

    # The history is fetched before the new question is stored.
    # NOTE(review): confirm whether the message added below also appears in
    # this collection by the time it is formatted.
    chat_history = crud.get_user_chat_history(db=db, username=username)
    user_message = schemas.MessageBase(message=user_request.question, type='User')
    crud.add_message(db, user_message, username)

    rag_input = schemas.RagInput(
        question=user_request.question,
        chat_history=format_chat_history(chat_history),
    )

    response_logger = LogResponseCallback(user_request, db)
    events = generate_stream(rag_input, rag_chain, [response_logger])
    return EventSourceResponse(events)
91
+
92
@app.post("/filtered_rag/stream")
async def filtered_rag_stream(request: Request, db: Session = Depends(get_db)):
    """Like ``rag_stream`` but with ``hybrid_search`` switched on.

    The raw request payload is printed before it is validated.
    """
    payload = await request.json()
    print(payload)
    user_request = schemas.UserRequest(**payload['input'])
    username = user_request.username

    # The history is fetched before the new question is stored.
    # NOTE(review): confirm whether the message added below also appears in
    # this collection by the time it is formatted.
    chat_history = crud.get_user_chat_history(db=db, username=username)
    user_message = schemas.MessageBase(message=user_request.question, type='User')
    crud.add_message(db, user_message, username)

    rag_input = schemas.RagInput(
        question=user_request.question,
        chat_history=format_chat_history(chat_history),
        hybrid_search=True,
    )

    response_logger = LogResponseCallback(user_request, db)
    events = generate_stream(rag_input, rag_chain, [response_logger])
    return EventSourceResponse(events)
112
+
113
+
114
+
115
+
116
if __name__ == "__main__":
    # Local development entry point: serve the app on localhost:8002 with
    # auto-reload enabled.
    import uvicorn

    uvicorn.run("main:app", port=8002, host="localhost", reload=True)
models.py ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from sqlalchemy import Column, ForeignKey, Integer, String, DateTime
2
+ from sqlalchemy.orm import relationship
3
+ from datetime import datetime
4
+
5
+ from database import Base
6
+
7
class User(Base):
    """A chat user, identified by a unique username."""

    __tablename__ = "users"

    id = Column(Integer, index=True, primary_key=True)
    username = Column(String, index=True, unique=True)

    # Messages of this user; mirrored by ``Message.user``.
    messages = relationship("Message", back_populates="user")
13
+
14
class Message(Base):
    """A single chat message that belongs to a user."""

    __tablename__ = "messages"

    id = Column(Integer, index=True, primary_key=True)
    user_id = Column(Integer, ForeignKey("users.id"))
    # Text of the message.
    message = Column(String)
    # Author label; the application stores 'User' or 'AI'.
    type = Column(String)
    # Defaults to the result of datetime.now when no value is given.
    timestamp = Column(DateTime, default=datetime.now)

    # Owning user; mirrored by ``User.messages``.
    user = relationship("User", back_populates="messages")
prompts.py ADDED
@@ -0,0 +1,70 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import AutoTokenizer
2
+ from langchain_core.prompts import PromptTemplate
3
+ from typing import List
4
+ import models
5
+
6
# Tokenizer of the served model. format_prompt() uses its chat template.
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B-Instruct")

# --- Raw template texts; the {placeholders} are filled in by the chains. ---

# The bare question, nothing else.
raw_prompt = "{question}"

# Answer a follow-up question given the previous conversation.
history_prompt = """
Given the following conversation provide a helpful answer to the follow up question.

Chat History:
{chat_history}

Follow Up question: {question}
helpful answer:
"""

# Rewrite a follow-up question so it can be understood without the history.
standalone_prompt = """
Given the following conversation and a follow up question, rephrase the
follow up question to be a standalone question, in its original language.

Chat History:
{chat_history}

Follow Up Input: {question}

Standalone question:
"""

# Answer a question strictly from the supplied context.
rag_prompt = """
Answer the question based only on the following context:
{context}

Question: {standalone_question}
"""
38
+
39
+
40
def format_prompt(prompt):
    """Wrap *prompt* in the tokenizer's chat template.

    The text becomes the user turn after a fixed system message, and the
    rendered string is returned as a ``PromptTemplate``.
    """
    system_turn = {"role": "system", "content": "You are a helpful AI assistant."}
    user_turn = {"role": "user", "content": prompt}

    rendered = tokenizer.apply_chat_template(
        [system_turn, user_turn],
        tokenize=False,
        add_generation_prompt=True,
    )
    return PromptTemplate.from_template(rendered)
53
+
54
+
55
def format_chat_history(messages: List[models.Message]):
    """Render *messages* as one ``"<type>: <message>"`` line per message."""
    rendered = []
    for entry in messages:
        rendered.append(f'{entry.type}: {entry.message}')
    return '\n'.join(rendered)
60
+
61
def format_context(docs: List[str]):
    """Concatenate the retrieved texts, separated by one blank line."""
    separator = '\n\n'
    return separator.join(docs)
63
+
64
+
65
# Build the template objects used by the chains. Order matters:
# raw_prompt_formatted must be created while raw_prompt is still the plain
# string, because raw_prompt is rebound to a PromptTemplate on the next line.
raw_prompt_formatted = format_prompt(raw_prompt)
raw_prompt = PromptTemplate.from_template(raw_prompt)

# Chat-template versions of the remaining prompts.
history_prompt_formatted = format_prompt(history_prompt)
standalone_prompt_formatted = format_prompt(standalone_prompt)
rag_prompt_formatted = format_prompt(rag_prompt)
70
+
schemas.py ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from pydantic.v1 import BaseModel
2
+ from datetime import datetime
3
+ from typing import Optional
4
+
5
class UserQuestion(BaseModel):
    """Input carrying only the user's question."""

    # Text of the question.
    question: str
7
+
8
class UserRequest(UserQuestion):
    """A question together with the name of the user asking it."""

    # Name of the user asking the question.
    username: str
10
+
11
class HistoryInput(BaseModel):
    """Chain input made of the chat history text and the new question."""

    # Earlier conversation, already rendered as text.
    chat_history: str
    # The new question to answer.
    question: str
14
+
15
class RagInput(HistoryInput):
    """``HistoryInput`` plus a switch for hybrid search."""

    # Off unless the caller enables it.
    hybrid_search: bool = False
17
+
18
class MessageBase(BaseModel):
    """Schema of a chat message; only ``message`` and ``type`` are required."""

    # Optional fields default to None when not supplied.
    id: Optional[int] = None
    user_id: Optional[int] = None
    # Text of the message.
    message: str
    # Author label; the application uses 'User' and 'AI'.
    type: str
    timestamp: Optional[datetime] = None

    class Config:
        # Pydantic v1 ORM mode.
        orm_mode = True