nabin2004 committed
Commit 51f9bd2 · verified · 1 Parent(s): 06b9974

Upload 2 files

Files changed (2)
  1. app.py +98 -0
  2. requirements.txt +6 -0
app.py ADDED
@@ -0,0 +1,98 @@
+ import asyncio
+ import json
+ import time
+
+ from typing import Optional, List
+
+ from pydantic import BaseModel
+
+ from starlette.responses import StreamingResponse
+ from fastapi import FastAPI, HTTPException
+
+ import torch
+
+ app = FastAPI(title="OpenAI-compatible API")
+
+
+ # Load model directly
+ from transformers import AutoTokenizer, AutoModelForCausalLM
+
+ tokenizer = AutoTokenizer.from_pretrained("microsoft/phi-2")
+ model = AutoModelForCausalLM.from_pretrained("microsoft/phi-2")
+
+
+ # data models
+ class Message(BaseModel):
+     role: str
+     content: str
+
+
+ class ChatCompletionRequest(BaseModel):
+     model: Optional[str] = "mock-gpt-model"
+     messages: List[Message]
+     max_tokens: Optional[int] = 512
+     temperature: Optional[float] = 0.1
+     stream: Optional[bool] = False
+
+
+ async def _resp_async_generator(text_resp: str, model: str):
+     # Re-emit the finished reply word by word as OpenAI-style SSE chunks.
+     tokens = text_resp.split(" ")
+
+     for i, token in enumerate(tokens):
+         chunk = {
+             "id": i,
+             "object": "chat.completion.chunk",
+             "created": int(time.time()),
+             "model": model,
+             "choices": [{"delta": {"content": token + " "}}],
+         }
+         yield f"data: {json.dumps(chunk)}\n\n"
+         await asyncio.sleep(0.05)
+     yield "data: [DONE]\n\n"
+
+
+ @app.post("/chat/completions")
+ def chat_completions(request: ChatCompletionRequest):
+     # A plain `def` endpoint: FastAPI runs it in a threadpool, so the
+     # blocking model.generate() call does not stall the event loop.
+     if not request.messages:
+         raise HTTPException(status_code=400, detail="No messages provided.")
+
+     # Build the prompt from messages
+     prompt = ""
+     for msg in request.messages:
+         if msg.role == "user":
+             prompt += f"User: {msg.content}\n"
+         elif msg.role == "assistant":
+             prompt += f"Assistant: {msg.content}\n"
+     prompt += "Assistant:"
+
+     # Tokenize and generate
+     inputs = tokenizer(prompt, return_tensors="pt")
+     with torch.no_grad():
+         outputs = model.generate(
+             **inputs,
+             max_new_tokens=request.max_tokens,
+             temperature=request.temperature,
+             do_sample=True,
+             pad_token_id=tokenizer.eos_token_id,
+         )
+
+     # Extract only the assistant's new reply: slice off the prompt tokens
+     # before decoding, since decoding the full sequence and cutting at
+     # len(prompt) can misalign when the tokenizer normalizes whitespace.
+     new_tokens = outputs[0][inputs["input_ids"].shape[1]:]
+     assistant_reply = tokenizer.decode(new_tokens, skip_special_tokens=True).strip()
+
+     if request.stream:
+         return StreamingResponse(
+             _resp_async_generator(assistant_reply, request.model),
+             media_type="text/event-stream",
+         )
+
+     return {
+         "id": "1337",
+         "object": "chat.completion",
+         "created": int(time.time()),
+         "model": request.model,
+         "choices": [{"message": Message(role="assistant", content=assistant_reply)}],
+     }
+
+
+ # if __name__ == "__main__":
+ #     import uvicorn
+ #     uvicorn.run(app, host="0.0.0.0", port=8000)
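A quick way to exercise both response paths, should it be useful: the sketch below posts to the endpoint with the requests library (not part of requirements.txt) and assumes the app is served locally on port 8000, e.g. via uvicorn app:app. The host, port, and prompt are illustrative, not part of this commit.

import json
import requests

BASE_URL = "http://localhost:8000"  # assumed local uvicorn instance

# Non-streaming: the reply arrives as a single chat.completion object.
resp = requests.post(
    f"{BASE_URL}/chat/completions",
    json={"messages": [{"role": "user", "content": "Say hello."}]},
)
print(resp.json()["choices"][0]["message"]["content"])

# Streaming: read server-sent events until the [DONE] sentinel.
with requests.post(
    f"{BASE_URL}/chat/completions",
    json={"messages": [{"role": "user", "content": "Say hello."}], "stream": True},
    stream=True,
) as resp:
    for line in resp.iter_lines():
        if not line:
            continue  # skip the blank lines between SSE events
        payload = line.decode().removeprefix("data: ")
        if payload == "[DONE]":
            break
        chunk = json.loads(payload)
        print(chunk["choices"][0]["delta"]["content"], end="", flush=True)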
requirements.txt ADDED
@@ -0,0 +1,6 @@
+ fastapi
+ uvicorn
+ transformers
+ torch
+ pydantic
+ starlette
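Because the route and payload mirror OpenAI's chat-completions shape, the official openai Python client (also not in requirements.txt) can be pointed at this server. A minimal sketch, assuming the app is running on localhost:8000; base_url replaces the client's default https://api.openai.com/v1, and the placeholder API key is never checked by this app. The client parses responses permissively, but since the app omits fields such as index and finish_reason, stricter client versions may warrant the raw-HTTP approach shown above.

from openai import OpenAI

# Point the client at the local app; the api_key is a placeholder
# because the server performs no authentication.
client = OpenAI(base_url="http://localhost:8000", api_key="unused")

completion = client.chat.completions.create(
    model="mock-gpt-model",
    messages=[{"role": "user", "content": "What is FastAPI?"}],
)
print(completion.choices[0].message.content)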