Spaces:
Paused
Paused
| from typing import List | |
| import fastapi | |
| import markdown | |
| import uvicorn | |
| from ctransformers import AutoModelForCausalLM | |
| from fastapi.responses import HTMLResponse | |
| from fastapi.middleware.cors import CORSMiddleware | |
| from sse_starlette.sse import EventSourceResponse | |
| from pydantic import BaseModel, Field | |
| from typing_extensions import Literal | |
| from dialogue import DialogueTemplate | |
# Load the `q4_0` GGML build of StarChat-Alpha once, at import time; every
# request handler in this module shares this single `llm` instance.
llm = AutoModelForCausalLM.from_pretrained(
    "NeoDim/starchat-alpha-GGML",
    model_type="starcoder",
    model_file="starchat-alpha-ggml-q4_0.bin",
)

app = fastapi.FastAPI(title="Starchat Alpha")

# Fully permissive CORS: any origin, method and header is accepted and
# credentials are allowed.
# NOTE(review): wildcard origins together with allow_credentials=True lets any
# site make credentialed requests; no auth is visible here, but confirm intent.
app.add_middleware(
    CORSMiddleware,
    allow_headers=["*"],
    allow_methods=["*"],
    allow_origins=["*"],
    allow_credentials=True,
)
@app.get("/")
async def index():
    """Serve ``README.md`` rendered from Markdown to HTML.

    The file is opened by relative path (i.e. from the process's current
    working directory) on every request, so edits to it show up without a
    restart.
    """
    with open("README.md", "r", encoding="utf-8") as readme_file:
        md_template_string = readme_file.read()
    html_content = markdown.markdown(md_template_string)
    return HTMLResponse(content=html_content, status_code=200)
@app.get("/demo")
async def demo():
    """Serve a minimal browser page that live-streams the ``/stream`` output.

    The page opens an ``EventSource`` on ``/stream`` of the same host that
    served it. It appends each received event to a ``<wc-markdown>`` element,
    with a few ad-hoc substitutions for code fences and the ``<|assistant|>`` /
    ``<|end|>`` markers. It closes the stream when generation ends so the
    browser does not auto-reconnect and start a new generation.
    """
    html_content = """
<!DOCTYPE html>
<html>
<head>
<script src="https://cdnjs.cloudflare.com/ajax/libs/showdown/1.9.1/showdown.min.js"></script>
</head>
<body>
<style>
body {
  font-family: -apple-system,BlinkMacSystemFont,"Segoe UI",Helvetica,Arial,sans-serif,"Apple Color Emoji","Segoe UI Emoji","Segoe UI Symbol";
}
code {
  font-family: "SFMono-Regular",Consolas,"Liberation Mono",Menlo,Courier,monospace !important;
  display: inline-block;
  background-color: lightgray;
}
h1, h2, h3, h4, h5, h6 {
  font-family: Roboto,-apple-system,BlinkMacSystemFont,"Helvetica Neue","Segoe UI","Oxygen","Ubuntu","Cantarell","Open Sans",sans-serif;
}
#content {
  box-sizing: border-box;
  min-width: 200px;
  max-width: 980px;
  margin: 0 auto;
  padding: 45px;
  font-size: 16px;
}
@media (max-width: 767px) {
  #content {
    padding: 15px;
  }
}
</style>
<script type="module" src="https://cdn.skypack.dev/@vanillawc/wc-markdown"></script>
<wc-markdown id="content" highlight><h1>starchat-alpha-q4.0</h1></wc-markdown>
<script>
var converter = new showdown.Converter();
// Relative URL: stream from whichever host served this page rather than a
// hard-coded deployment.
var source = new EventSource("/stream");
let eventCache;
source.onmessage = function(event) {
  let eventData = event.data;
  console.log(eventData);
  if (eventData.includes("```")) {
    eventCache = true;
    return;
  }
  if (eventCache && !eventData.includes("```")) {
    const backticks = "```";
    eventData = `${backticks}${eventData}<br /><code>`;
    eventCache = false;
  }
  if (eventData === ":") {
    eventData = `${eventData}<br />`;
  }
  if (eventData === "<|assistant|>") {
    eventData = `<br />${eventData}`;
  }
  const done = eventData === "<|end|>";
  if (done) {
    eventData = "<br />";
  }
  document.getElementById("content").innerHTML = document.getElementById("content").innerHTML + eventData;
  if (done) {
    // EventSource reconnects automatically after the server closes the
    // stream, which would start a new generation; stop once the model is done.
    source.close();
  }
};
// A server-side end of stream also surfaces as an error; close instead of
// letting the browser reconnect and re-run generation.
source.onerror = function() {
  source.close();
};
</script>
</body>
</html>
"""
    return HTMLResponse(content=html_content, status_code=200)
@app.get("/stream")
async def chat(prompt="<|user|> Write an express server with server sent events. <|assistant|>"):
    """Stream a raw completion of ``prompt`` as server-sent events.

    The demo page connects to this endpoint. No dialogue template is applied:
    ``prompt`` is tokenized as-is, so callers must include any role markers
    themselves (as the default does with ``<|user|>`` / ``<|assistant|>``).

    The stream emits the prompt text first, then one event per generated
    token, then a final empty event.

    Args:
        prompt: Full prompt text, read from the query string; defaults to a
            demo request.
    """
    tokens = llm.tokenize(prompt)

    async def server_sent_events(chat_chunks, llm):
        yield prompt
        # NOTE(review): llm.generate is iterated synchronously, so each token
        # step runs on the event-loop thread; confirm this is acceptable when
        # several clients stream at once.
        for chat_chunk in llm.generate(chat_chunks):
            yield llm.detokenize(chat_chunk)
        yield ""

    return EventSourceResponse(server_sent_events(tokens, llm))
class ChatCompletionRequestMessage(BaseModel):
    # One chat turn: who is speaking (`role`) and what they said (`content`).
    # Both fields are optional in the request body and fall back to the
    # defaults given to Field below.
    role: Literal["system", "user", "assistant"] = Field(
        "user",
        description="The role of the message.",
    )
    content: str = Field(
        "",
        description="The content of the message.",
    )
class ChatCompletionRequest(BaseModel):
    # Request body holding the conversation to continue; an omitted
    # `messages` field defaults to an empty list.
    messages: List[ChatCompletionRequestMessage] = Field(
        [],
        description="A list of messages to generate completions for.",
    )
# System turn that is prepended to every chat conversation when building the prompt.
system_message = (
    "Below is a conversation between a human user "
    "and a helpful AI coding assistant."
)
@app.post("/v1/chat/completions")
async def chat(request: ChatCompletionRequest, response_mode=None):
    """Stream a chat completion for a conversation as server-sent events.

    The request's messages, plus ``system_message`` as the system turn, are
    rendered into a single prompt by ``DialogueTemplate``. Each generated
    token is then sent as its own SSE ``data`` event, and a final ``[DONE]``
    event marks the end of the stream.

    Note: this redefines the module-level name ``chat`` (an earlier handler
    uses the same name).

    Args:
        request: The conversation to continue.
        response_mode: Unused; kept so the signature stays unchanged.
    """
    # .dict() also converts the nested message models into plain dicts with
    # "role" and "content" keys before they are handed to DialogueTemplate.
    kwargs = request.dict()
    dialogue_template = DialogueTemplate(
        system=system_message, messages=kwargs["messages"]
    )
    prompt = dialogue_template.get_inference_prompt()
    tokens = llm.tokenize(prompt)

    async def server_sent_events(chat_chunks, llm):
        for token in llm.generate(chat_chunks):
            yield dict(data=llm.detokenize(token))
        yield dict(data="[DONE]")

    return EventSourceResponse(server_sent_events(tokens, llm))
if __name__ == "__main__":
    # Executing the module directly serves `app` with uvicorn, bound to every
    # IPv4 interface on port 8000.
    uvicorn.run(app, port=8000, host="0.0.0.0")