Spaces:
Sleeping
Sleeping
Commit
·
78c941e
1
Parent(s):
212b42a
implement formatted prompt
Browse files- app/__pycache__/main.cpython-311.pyc +0 -0
- app/chains.py +3 -2
- app/main.py +30 -7
- app/prompts.py +18 -4
app/__pycache__/main.cpython-311.pyc
CHANGED
|
Binary files a/app/__pycache__/main.cpython-311.pyc and b/app/__pycache__/main.cpython-311.pyc differ
|
|
|
app/chains.py
CHANGED
|
@@ -8,6 +8,7 @@ from langchain_core.runnables import RunnablePassthrough
|
|
| 8 |
import schemas
|
| 9 |
from prompts import (
|
| 10 |
raw_prompt,
|
|
|
|
| 11 |
format_context,
|
| 12 |
tokenizer
|
| 13 |
)
|
|
@@ -29,8 +30,8 @@ simple_chain = (raw_prompt | llm).with_types(input_type=schemas.UserQuestion)
|
|
| 29 |
|
| 30 |
# data_indexer = DataIndexer()
|
| 31 |
|
| 32 |
-
#
|
| 33 |
-
|
| 34 |
|
| 35 |
# # TODO: use history_prompt_formatted and HistoryInput to create the history_chain
|
| 36 |
# history_chain = None
|
|
|
|
| 8 |
import schemas
|
| 9 |
from prompts import (
|
| 10 |
raw_prompt,
|
| 11 |
+
raw_prompt_formatted,
|
| 12 |
format_context,
|
| 13 |
tokenizer
|
| 14 |
)
|
|
|
|
| 30 |
|
| 31 |
# data_indexer = DataIndexer()
|
| 32 |
|
| 33 |
+
# formatted_chain: pipe the chat-template-formatted prompt (raw_prompt_formatted)
# into the LLM endpoint; with_types(input_type=schemas.UserQuestion) lets
# LangServe validate the incoming payload against the UserQuestion schema.
formatted_chain = (raw_prompt_formatted | llm).with_types(input_type=schemas.UserQuestion)
|
| 35 |
|
| 36 |
# # TODO: use history_prompt_formatted and HistoryInput to create the history_chain
|
| 37 |
# history_chain = None
|
app/main.py
CHANGED
|
@@ -10,7 +10,7 @@ from typing import List
|
|
| 10 |
from sqlalchemy.orm import Session
|
| 11 |
|
| 12 |
import schemas
|
| 13 |
-
from chains import simple_chain
|
| 14 |
import crud, models, schemas
|
| 15 |
from database import SessionLocal, engine
|
| 16 |
from callbacks import LogResponseCallback
|
|
@@ -28,24 +28,47 @@ def get_db():
|
|
| 28 |
db.close()
|
| 29 |
|
| 30 |
# ..
|
|
|
|
| 31 |
async def generate_stream(input_data: schemas.BaseModel, runnable: Runnable, callbacks: List[BaseCallbackHandler]=[]):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 32 |
for output in runnable.stream(input_data.dict(), config={"callbacks": callbacks}):
|
| 33 |
data = WellKnownLCSerializer().dumps(output).decode("utf-8")
|
| 34 |
-
yield {'data': data, "event": "data"}
|
|
|
|
|
|
|
|
|
|
| 35 |
yield {"event": "end"}
|
| 36 |
|
| 37 |
-
|
|
|
|
| 38 |
@app.post("/simple/stream")
|
| 39 |
async def simple_stream(request: Request):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 40 |
data = await request.json()
|
| 41 |
user_question = schemas.UserQuestion(**data['input'])
|
|
|
|
|
|
|
| 42 |
return EventSourceResponse(generate_stream(user_question, simple_chain))
|
| 43 |
|
| 44 |
|
| 45 |
-
|
| 46 |
-
|
| 47 |
-
#
|
| 48 |
-
|
|
|
|
|
|
|
| 49 |
|
| 50 |
|
| 51 |
# @app.post("/history/stream")
|
|
|
|
| 10 |
from sqlalchemy.orm import Session
|
| 11 |
|
| 12 |
import schemas
|
| 13 |
+
from chains import simple_chain, formatted_chain
|
| 14 |
import crud, models, schemas
|
| 15 |
from database import SessionLocal, engine
|
| 16 |
from callbacks import LogResponseCallback
|
|
|
|
| 28 |
db.close()
|
| 29 |
|
| 30 |
# ..
|
| 31 |
+
# "async" marks the function as asynchronous, allowing it to pause and resume during operations like streaming or I/O.
|
| 32 |
async def generate_stream(input_data: schemas.BaseModel, runnable: Runnable, callbacks: List[BaseCallbackHandler] = None):
    """Stream a runnable's output to the client as server-sent events (SSE).

    Each chunk produced by ``runnable.stream`` is serialized with
    ``WellKnownLCSerializer`` and yielded as a ``data`` event; a final
    ``end`` event tells the client no more data will be sent.

    Args:
        input_data: pydantic model whose ``dict()`` feeds the runnable.
        runnable: LangChain runnable exposing ``.stream``.
        callbacks: optional LangChain callback handlers (default: none).

    Yields:
        dicts shaped for sse-starlette: ``{"data": ..., "event": "data"}``
        per chunk, then ``{"event": "end"}``.
    """
    # FIX: the default was a mutable list literal ([]), which is shared across
    # all calls (the classic mutable-default pitfall). Use a None sentinel
    # instead; passing an explicit list still behaves exactly as before.
    handlers = [] if callbacks is None else callbacks
    # NOTE(review): input_data.dict() is pydantic-v1 style — presumably the
    # project pins pydantic v1; confirm before migrating to model_dump().
    for output in runnable.stream(input_data.dict(), config={"callbacks": handlers}):
        data = WellKnownLCSerializer().dumps(output).decode("utf-8")
        yield {'data': data, "event": "data"}
    # After the stream is exhausted, signal completion so the client
    # stops listening for further events.
    yield {"event": "end"}
|
| 46 |
|
| 47 |
+
# Register simple_stream as the handler for HTTP POST requests to /simple/stream.
@app.post("/simple/stream")
async def simple_stream(request: Request):
    """Handle POST /simple/stream.

    Parses the JSON body, validates its 'input' field as a UserQuestion
    (pydantic performs the validation/coercion), and returns a server-sent
    event response that streams simple_chain's answer back to the client.
    """
    # request.json() is awaited because reading/parsing the body is async I/O.
    payload = await request.json()
    question = schemas.UserQuestion(**payload['input'])
    # EventSourceResponse streams the generator's events to the client in
    # real time over SSE.
    return EventSourceResponse(generate_stream(question, simple_chain))
|
| 64 |
|
| 65 |
|
| 66 |
+
# Register formatted_stream as the handler for HTTP POST requests to
# /formatted/stream (the chat-template-formatted variant of /simple/stream).
@app.post("/formatted/stream")
async def formatted_stream(request: Request):
    """Handle POST /formatted/stream.

    Validates the request body's 'input' field as a UserQuestion and streams
    formatted_chain's answer back to the client as server-sent events.
    """
    payload = await request.json()
    question = schemas.UserQuestion(**payload['input'])
    return EventSourceResponse(generate_stream(question, formatted_chain))
|
| 72 |
|
| 73 |
|
| 74 |
# @app.post("/history/stream")
|
app/prompts.py
CHANGED
|
@@ -16,15 +16,29 @@ login(os.environ['HF_TOKEN'])
|
|
| 16 |
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B-Instruct")
|
| 17 |
|
| 18 |
def format_prompt(prompt) -> PromptTemplate:
|
| 19 |
-
|
| 20 |
-
|
| 21 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 22 |
|
| 23 |
def format_chat_history(messages: List[models.Message]):
|
| 24 |
# TODO: implement format_chat_history to format
|
| 25 |
# the list of Message into a text of chat history.
|
| 26 |
raise NotImplemented
|
| 27 |
|
|
|
|
| 28 |
def format_context(docs: List[str]):
|
| 29 |
# TODO: the output of the DataIndexer.search is a list of text,
|
| 30 |
# so we need to concatenate that list into a text that can fit into
|
|
@@ -47,7 +61,7 @@ standalone_prompt: str = None
|
|
| 47 |
rag_prompt: str = None
|
| 48 |
|
| 49 |
# TODO: create raw_prompt_formatted by using format_prompt
|
| 50 |
-
raw_prompt_formatted =
|
| 51 |
raw_prompt = PromptTemplate.from_template(raw_prompt)
|
| 52 |
|
| 53 |
# TODO: use format_prompt to create history_prompt_formatted
|
|
|
|
| 16 |
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B-Instruct")
|
| 17 |
|
| 18 |
def format_prompt(prompt) -> PromptTemplate:
    """Wrap a raw prompt string in the model's chat-instruction template.

    Builds a two-message conversation (a fixed system role plus the given
    prompt as the user turn), renders it with the tokenizer's model-specific
    chat template, and returns the rendered text as a LangChain
    PromptTemplate so downstream chains can fill in its variables.
    """
    conversation = [
        {"role": "system", "content": "You are a helpful AI assistant."},
        {"role": "user", "content": prompt},
    ]
    # tokenize=False -> return the rendered conversation as a plain string
    # rather than token ids; add_generation_prompt=True -> append the marker
    # after which the model is expected to start generating its reply.
    rendered = tokenizer.apply_chat_template(
        conversation,
        tokenize=False,
        add_generation_prompt=True,
    )
    return PromptTemplate.from_template(rendered)
|
| 35 |
|
| 36 |
def format_chat_history(messages: List[models.Message]):
    """Format a list of Message records into a chat-history text block.

    Raises:
        NotImplementedError: always — not yet implemented.
    """
    # TODO: implement format_chat_history to format
    # the list of Message into a text of chat history.
    # FIX: `raise NotImplemented` is wrong — NotImplemented is a comparison
    # singleton, not an exception, so raising it is itself a TypeError.
    raise NotImplementedError
|
| 40 |
|
| 41 |
+
|
| 42 |
def format_context(docs: List[str]):
|
| 43 |
# TODO: the output of the DataIndexer.search is a list of text,
|
| 44 |
# so we need to concatenate that list into a text that can fit into
|
|
|
|
| 61 |
rag_prompt: str = None
|
| 62 |
|
| 63 |
# raw_prompt_formatted: the same prompt text wrapped in the model's chat
# template via format_prompt.
# NOTE: order matters — this must run while raw_prompt is still the plain
# string, before the next line rebinds raw_prompt to a PromptTemplate.
raw_prompt_formatted = format_prompt(raw_prompt)
raw_prompt = PromptTemplate.from_template(raw_prompt)
|
| 66 |
|
| 67 |
# TODO: use format_prompt to create history_prompt_formatted
|