Jadyro committed on
Commit
f1dbca3
·
verified ·
1 Parent(s): 4e9db80

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +80 -41
app.py CHANGED
@@ -1,4 +1,4 @@
1
- from fastapi import FastAPI
2
  from pydantic import BaseModel
3
  from typing import List, Optional
4
  from transformers import AutoTokenizer, pipeline
@@ -11,7 +11,7 @@ pipe = pipeline(
11
  "text-generation",
12
  model=MODEL_ID,
13
  tokenizer=tokenizer,
14
- device_map="auto", # "cpu" if you want to force CPU
15
  max_new_tokens=512,
16
  )
17
 
@@ -24,45 +24,85 @@ class ChatMessage(BaseModel):
24
 
25
 
26
  class ChatRequest(BaseModel):
27
- model: Optional[str] = None # ignored, for OpenAI-compat
28
  messages: List[ChatMessage]
29
  temperature: Optional[float] = 0.0
30
  max_tokens: Optional[int] = 512
31
 
32
 
33
- class ChatChoiceMessage(BaseModel):
34
- role: str
35
- content: str
36
-
37
-
38
- class ChatChoice(BaseModel):
39
- index: int
40
- message: ChatChoiceMessage
41
- finish_reason: str
42
-
43
-
44
- class ChatResponse(BaseModel):
45
- id: str
46
- object: str
47
- choices: List[ChatChoice]
48
-
49
-
50
  @app.get("/")
51
  def root():
52
  return {"status": "ok", "model": MODEL_ID}
53
 
54
 
55
- @app.post("/v1/chat/completions", response_model=ChatResponse)
56
- def chat(request: ChatRequest):
57
- # Convert Pydantic objects to plain dicts
58
- messages = [m.dict() for m in request.messages]
59
-
60
- # Use the model's chat template
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
61
  prompt = tokenizer.apply_chat_template(
62
- messages,
63
  tokenize=False,
64
  add_generation_prompt=True,
65
  )
 
 
 
 
 
 
 
 
 
 
 
66
 
67
  outputs = pipe(
68
  prompt,
@@ -75,18 +115,17 @@ def chat(request: ChatRequest):
75
  full = outputs[0]["generated_text"]
76
  reply = full[len(prompt):].strip()
77
 
78
- response = ChatResponse(
79
- id="chatcmpl-1",
80
- object="chat.completion",
81
- choices=[
82
- ChatChoice(
83
- index=0,
84
- message=ChatChoiceMessage(
85
- role="assistant",
86
- content=reply,
87
- ),
88
- finish_reason="stop",
89
- )
90
  ],
91
- )
92
- return response
 
1
+ from fastapi import FastAPI, HTTPException
2
  from pydantic import BaseModel
3
  from typing import List, Optional
4
  from transformers import AutoTokenizer, pipeline
 
11
  "text-generation",
12
  model=MODEL_ID,
13
  tokenizer=tokenizer,
14
+ device_map="auto", # "cpu" on HF’s free tier
15
  max_new_tokens=512,
16
  )
17
 
 
24
 
25
 
26
class ChatRequest(BaseModel):
    """Request body for the OpenAI-compatible /v1/chat/completions endpoint."""
    model: Optional[str] = None  # ignored, OpenAI-style compat
    messages: List[ChatMessage]
    temperature: Optional[float] = 0.0
    max_tokens: Optional[int] = 512
31
 
32
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
33
@app.get("/")
def root():
    """Health-check endpoint: report service status and the served model id."""
    payload = {"status": "ok", "model": MODEL_ID}
    return payload
36
 
37
 
38
def _normalize_messages(raw_messages: List[dict]) -> List[dict]:
    """Collapse an OpenAI-style message list into strict user/assistant turns.

    - System messages are collected and their text is prepended to the first
      user message.
    - Leading assistant messages are dropped (the template must start with a
      user turn).
    - Consecutive messages with the same role are merged, separated by a
      blank line.

    Args:
        raw_messages: dicts with "role" / "content" keys (OpenAI chat format).

    Returns:
        A list of alternating {"role", "content"} dicts starting with "user".

    Raises:
        ValueError: if no user message survives normalization.
    """
    system_parts: List[str] = []
    ua_messages: List[dict] = []

    # Separate system vs user/assistant turns.
    for m in raw_messages:
        role = m.get("role")
        # OpenAI allows `content: null`; `get(..., "")` would still return
        # None in that case, so coerce explicitly to keep concatenation safe.
        content = m.get("content") or ""
        if role == "system":
            if content:
                system_parts.append(content)
        elif role in ("user", "assistant"):
            ua_messages.append({"role": role, "content": content})
        # Any other role (e.g. "tool") is intentionally ignored.

    # Drop leading assistants (template wants to start with a user turn).
    while ua_messages and ua_messages[0]["role"] != "user":
        ua_messages.pop(0)

    # Merge consecutive messages that share a role.
    normalized: List[dict] = []
    for m in ua_messages:
        if normalized and normalized[-1]["role"] == m["role"]:
            normalized[-1]["content"] += "\n\n" + m["content"]
        else:
            normalized.append(m)

    if not normalized:
        raise ValueError("No user messages found after normalization.")

    # Invariant: normalized[0] is a user turn (all leading assistants were
    # dropped above), so system text can always be folded into it directly.
    if system_parts:
        system_text = "\n\n".join(system_parts)
        normalized[0]["content"] = system_text + "\n\n" + normalized[0]["content"]

    return normalized


def build_prompt(raw_messages: List[dict]) -> str:
    """Render OpenAI-style messages into the model's chat prompt string.

    Args:
        raw_messages: dicts with "role" / "content" keys (OpenAI chat format).

    Returns:
        The untokenized prompt text with the generation cue appended.

    Raises:
        ValueError: if normalization leaves no user message.
    """
    normalized = _normalize_messages(raw_messages)
    # Let tokenizer.apply_chat_template enforce the model's exact format.
    prompt = tokenizer.apply_chat_template(
        normalized,
        tokenize=False,
        add_generation_prompt=True,
    )
    return prompt
96
+
97
+
98
+ @app.post("/v1/chat/completions")
99
+ def chat(request: ChatRequest):
100
+ try:
101
+ messages = [m.dict() for m in request.messages]
102
+ prompt = build_prompt(messages)
103
+ except Exception as e:
104
+ # Don't crash the app – return a 400 with explanation
105
+ raise HTTPException(status_code=400, detail=f"Invalid message history: {e}")
106
 
107
  outputs = pipe(
108
  prompt,
 
115
  full = outputs[0]["generated_text"]
116
  reply = full[len(prompt):].strip()
117
 
118
+ return {
119
+ "id": "chatcmpl-1",
120
+ "object": "chat.completion",
121
+ "choices": [
122
+ {
123
+ "index": 0,
124
+ "message": {
125
+ "role": "assistant",
126
+ "content": reply,
127
+ },
128
+ "finish_reason": "stop",
129
+ }
130
  ],
131
+ }