1MR committed on
Commit
f4c2faa
·
verified ·
1 Parent(s): 421bbb8

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +239 -0
app.py ADDED
@@ -0,0 +1,239 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fastapi import FastAPI, HTTPException
2
+ from pydantic import BaseModel
3
+ from typing import List, Optional, Dict, Any
4
+ import torch
5
+ from transformers import AutoTokenizer, AutoModelForCausalLM
6
+ import uvicorn
7
+ import logging
8
+ from contextlib import asynccontextmanager
9
+
10
# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Global model/tokenizer handles. They stay None until the FastAPI lifespan
# hook populates them at startup; endpoints check for None and answer 503.
model = None
tokenizer = None
17
+
18
# Request/Response models
class ChatMessage(BaseModel):
    """One turn of a conversation."""
    role: str  # "system", "user", "assistant" (other roles are ignored downstream)
    content: str  # message text for this turn
22
+
23
class ChatRequest(BaseModel):
    """Request body for /chat/completions and /chat/stream."""
    messages: List[ChatMessage]  # conversation history to complete
    max_tokens: Optional[int] = 512  # cap on newly generated tokens
    temperature: Optional[float] = 0.7  # sampling temperature
    top_p: Optional[float] = 0.9  # nucleus-sampling threshold
    stop: Optional[List[str]] = None  # optional stop strings that truncate output
29
+
30
class ChatResponse(BaseModel):
    """Non-streaming completion result."""
    content: str  # generated assistant text
    finish_reason: str  # why generation ended (this server always sends "stop")
    usage: Dict[str, int]  # input_tokens / output_tokens / total_tokens
34
+
35
class ChatStreamChunk(BaseModel):
    """One server-sent chunk of a streamed completion."""
    content: str  # text fragment carried by this chunk
    finish_reason: Optional[str] = None  # set to "stop" only on the final chunk
    usage: Optional[Dict[str, int]] = None  # token counts, final chunk only
39
+
40
@asynccontextmanager
async def lifespan(app: FastAPI):
    """FastAPI lifespan hook.

    Startup: populate the module-level ``model`` and ``tokenizer`` globals
    from a Hugging Face checkpoint. Shutdown: just log; nothing to release.
    """
    global model, tokenizer

    logger.info("Loading model and tokenizer...")

    # Swap in your own checkpoint here (hub id or local path).
    checkpoint = "Qwen/Qwen3-4B"
    # checkpoint = "your-username/your-fine-tuned-model"

    try:
        tokenizer = AutoTokenizer.from_pretrained(checkpoint)
        model = AutoModelForCausalLM.from_pretrained(
            checkpoint,
            torch_dtype=torch.float16,
            device_map="auto",
            trust_remote_code=True,
        )

        # Some checkpoints ship without a pad token; reuse EOS so padding works.
        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token

        logger.info(f"Model loaded successfully: {checkpoint}")
    except Exception as e:
        logger.error(f"Failed to load model: {e}")
        raise e

    yield

    logger.info("Shutting down...")
73
+
74
# Initialize FastAPI app; `lifespan` loads the model before requests are served.
app = FastAPI(
    title="Custom Chat Model API",
    description="API for fine-tuned chat model",
    version="1.0.0",
    lifespan=lifespan
)
81
+
82
def format_messages(messages: List[ChatMessage]) -> str:
    """Flatten a chat history into a single plain-text prompt.

    Each message becomes a "<Role>: <content>\n" line; messages with any
    other role are silently skipped. A trailing "Assistant:" cue is appended
    so the model continues with the assistant's turn.
    """
    role_labels = {"system": "System", "user": "User", "assistant": "Assistant"}

    lines = [
        f"{role_labels[msg.role]}: {msg.content}\n"
        for msg in messages
        if msg.role in role_labels
    ]
    # Cue the model to produce the next assistant turn.
    lines.append("Assistant:")

    return "".join(lines)
97
+
98
def generate_response(
    prompt: str,
    max_tokens: int = 512,
    temperature: float = 0.7,
    top_p: float = 0.9,
    stop: Optional[List[str]] = None
) -> tuple[str, Dict[str, int]]:
    """Generate a completion for *prompt* with the globally loaded model.

    Args:
        prompt: Fully formatted prompt text.
        max_tokens: Maximum number of new tokens to sample.
        temperature: Sampling temperature (do_sample is always on).
        top_p: Nucleus-sampling threshold.
        stop: Optional stop strings; the output is cut at the earliest
            occurrence of any of them.

    Returns:
        ``(text, usage)`` where ``text`` is the stripped completion and
        ``usage`` maps input_tokens / output_tokens / total_tokens.
    """

    # Prompt is truncated to 2048 tokens; the model's true context window may
    # differ — TODO confirm against the loaded checkpoint.
    inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=2048)
    input_ids = inputs["input_ids"].to(model.device)
    attention_mask = inputs["attention_mask"].to(model.device)

    input_length = input_ids.shape[1]

    # Generate response (inference only, no gradients).
    with torch.no_grad():
        outputs = model.generate(
            input_ids=input_ids,
            attention_mask=attention_mask,
            max_new_tokens=max_tokens,
            temperature=temperature,
            top_p=top_p,
            do_sample=True,
            pad_token_id=tokenizer.pad_token_id,
            eos_token_id=tokenizer.eos_token_id,
            repetition_penalty=1.1
        )

    # Decode only the generated part
    generated_ids = outputs[0][input_length:]
    response = tokenizer.decode(generated_ids, skip_special_tokens=True)

    # Truncate at the EARLIEST stop sequence in the text. (The previous code
    # split on the first matching string in `stop` list order, which could
    # leave text past an earlier stop sequence in the output.)
    if stop:
        cut = min(
            (response.index(s) for s in stop if s in response),
            default=None,
        )
        if cut is not None:
            response = response[:cut]

    # Token accounting is done on the (possibly truncated) response text.
    output_tokens = len(tokenizer.encode(response))
    usage = {
        "input_tokens": input_length,
        "output_tokens": output_tokens,
        "total_tokens": input_length + output_tokens
    }

    return response.strip(), usage
147
+
148
@app.get("/")
async def root():
    """Landing endpoint: identify the service and confirm it is running."""
    payload = {"message": "Custom Chat Model API", "status": "running"}
    return payload
151
+
152
@app.get("/health")
async def health_check():
    """Health probe: reports whether the model global has been populated."""
    loaded = model is not None
    return {"status": "healthy", "model_loaded": loaded}
155
+
156
@app.post("/chat/completions", response_model=ChatResponse)
async def chat_completions(request: ChatRequest):
    """Non-streaming chat completion: format, generate, wrap in ChatResponse."""

    # Reject requests until the lifespan hook has finished loading weights.
    if model is None or tokenizer is None:
        raise HTTPException(status_code=503, detail="Model not loaded")

    try:
        # Build the prompt from the conversation and run generation.
        text, usage = generate_response(
            prompt=format_messages(request.messages),
            max_tokens=request.max_tokens,
            temperature=request.temperature,
            top_p=request.top_p,
            stop=request.stop,
        )
        return ChatResponse(
            content=text,
            finish_reason="stop",
            usage=usage,
        )
    except Exception as e:
        # Surface generation failures as a 500 with the error text.
        logger.error(f"Error in chat completion: {e}")
        raise HTTPException(status_code=500, detail=str(e))
185
+
186
@app.post("/chat/stream")
async def chat_stream(request: ChatRequest):
    """Streaming chat completion endpoint (server-sent events).

    Streaming is simulated: the full completion is generated first, then
    re-emitted word by word as SSE ``data:`` chunks, followed by a final
    chunk carrying usage info and a ``[DONE]`` sentinel.
    """
    from fastapi.responses import StreamingResponse
    import json

    if model is None or tokenizer is None:
        raise HTTPException(status_code=503, detail="Model not loaded")

    try:
        prompt = format_messages(request.messages)

        # Run generation EAGERLY, before returning the response. In the old
        # code this happened lazily inside the generator — after the response
        # headers were already sent — so the except below could never turn a
        # generation failure into an HTTP 500.
        # NOTE(review): generate_response blocks; consider offloading to a
        # thread pool so the event loop is not stalled.
        response_content, usage = generate_response(
            prompt=prompt,
            max_tokens=request.max_tokens,
            temperature=request.temperature,
            top_p=request.top_p,
            stop=request.stop
        )
    except Exception as e:
        logger.error(f"Error in streaming: {e}")
        raise HTTPException(status_code=500, detail=str(e))

    def generate_stream():
        # Re-chunk the finished completion one word at a time.
        words = response_content.split()
        for i, word in enumerate(words):
            chunk = ChatStreamChunk(
                content=word + " " if i < len(words) - 1 else word,
                finish_reason=None
            )
            yield f"data: {json.dumps(chunk.dict())}\n\n"

        # Final chunk carries the finish reason and token usage.
        final_chunk = ChatStreamChunk(
            content="",
            finish_reason="stop",
            usage=usage
        )
        yield f"data: {json.dumps(final_chunk.dict())}\n\n"
        yield "data: [DONE]\n\n"

    return StreamingResponse(
        generate_stream(),
        # SSE frames must be served as text/event-stream (was text/plain,
        # which event-source clients do not recognize as a stream).
        media_type="text/event-stream",
        headers={"Cache-Control": "no-cache"}
    )
237
+
238
if __name__ == "__main__":
    # Serve on all interfaces; 7860 is the conventional Hugging Face Spaces port.
    uvicorn.run(app, host="0.0.0.0", port=7860)