Spaces:
Sleeping
Sleeping
File size: 9,022 Bytes
f4c2faa bcca643 eb34210 bcca643 f4c2faa bcca643 f4c2faa bcca643 f4c2faa bcca643 f4c2faa bcca643 f4c2faa bcca643 f4c2faa |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 |
import logging
from contextlib import asynccontextmanager
from typing import Any, Dict, List, Optional

import torch
import uvicorn
from fastapi import FastAPI, HTTPException
from fastapi.concurrency import run_in_threadpool
from pydantic import BaseModel, Field
from transformers import AutoModelForCausalLM, AutoTokenizer
# Module-wide logging: configure the root logger at INFO and create this
# module's own logger.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Shared model state. Both start out unset and are assigned by the
# `lifespan` startup hook; the endpoints treat None as "not loaded yet".
model, tokenizer = None, None
# Request/Response models
class ChatMessage(BaseModel):
    """One conversation turn supplied by the client."""
    role: str  # expected: "system", "user", "assistant"
    content: str  # text of the turn
class ChatRequest(BaseModel):
    """Request body shared by the completion and streaming endpoints.

    Sampling parameters are range-checked here, so out-of-range values are
    rejected by request validation (422) instead of being handed to the
    generation code. ``None`` is still accepted for every optional field.
    """
    # Conversation so far, oldest turn first.
    messages: List[ChatMessage]
    # Upper bound on newly generated tokens; must be at least 1.
    max_tokens: Optional[int] = Field(512, ge=1)
    # Sampling temperature; negative values are rejected.
    temperature: Optional[float] = Field(0.7, ge=0.0)
    # Nucleus-sampling probability mass, within [0, 1].
    top_p: Optional[float] = Field(0.9, ge=0.0, le=1.0)
    # Optional stop sequences; the generated text is cut at a match.
    stop: Optional[List[str]] = None
class ChatResponse(BaseModel):
    """Body returned by the non-streaming completion endpoint."""

    # Generated assistant text.
    content: str
    # Why generation ended.
    finish_reason: str
    # Token accounting keyed by counter name (input/output/total tokens).
    usage: dict[str, int]
class ChatStreamChunk(BaseModel):
    """Payload of one ``data:`` event emitted by the streaming endpoint."""

    content: str  # text fragment; empty in the closing chunk
    finish_reason: Optional[str] = None  # left unset on intermediate chunks
    usage: Optional[dict[str, int]] = None  # token counts, closing chunk only
def _load_model_and_tokenizer(
    model_name: str,
    tokenizer_kwargs: Optional[Dict[str, Any]] = None,
    model_kwargs: Optional[Dict[str, Any]] = None,
):
    """Load a ``(tokenizer, model)`` pair for *model_name*.

    Uses float16 with ``device_map="auto"`` when CUDA is available, and
    float32 with default placement otherwise. If the tokenizer has no pad
    token, its EOS token is reused as the pad token.

    Args:
        model_name: Hub id or local path passed to ``from_pretrained``.
        tokenizer_kwargs: Extra keyword arguments for the tokenizer load.
        model_kwargs: Extra keyword arguments for the model load.

    Returns:
        ``(tokenizer, model)``. Exceptions from either load propagate.
    """
    use_cuda = torch.cuda.is_available()
    tok = AutoTokenizer.from_pretrained(model_name, **(tokenizer_kwargs or {}))
    mdl = AutoModelForCausalLM.from_pretrained(
        model_name,
        torch_dtype=torch.float16 if use_cuda else torch.float32,
        device_map="auto" if use_cuda else None,
        **(model_kwargs or {}),
    )
    # Padding needs a pad id; fall back to EOS when the tokenizer defines none.
    if tok.pad_token is None:
        tok.pad_token = tok.eos_token
    return tok, mdl


@asynccontextmanager
async def lifespan(app: FastAPI):
    """Load the model at startup and release it at shutdown.

    Tries the primary model first. On any failure it logs the traceback and
    falls back to ``gpt2``. If that also fails, the exception propagates and
    startup aborts. The module globals ``model``/``tokenizer`` are assigned
    only after a complete pair has loaded, so they are never left mismatched.
    """
    global model, tokenizer
    logger.info("Loading model and tokenizer...")
    # Smaller options such as "Qwen/Qwen1.5-0.5B-Chat" or
    # "Qwen/Qwen2-0.5B-Instruct" can be substituted on constrained hardware.
    model_name = "Qwen/Qwen2.5-7B-Instruct"
    try:
        # NOTE(review): trust_remote_code=True lets the model repository run
        # its own Python code at load time. It is presumably unnecessary for
        # architectures the installed transformers supports natively; confirm
        # this and drop the flag if so. use_fast=False selects the slow
        # tokenizer.
        tokenizer, model = _load_model_and_tokenizer(
            model_name,
            tokenizer_kwargs={"trust_remote_code": True, "use_fast": False},
            model_kwargs={"trust_remote_code": True},
        )
        logger.info(f"Model loaded successfully: {model_name}")
    except Exception as e:
        # logger.exception keeps the traceback, which is essential for
        # diagnosing load failures.
        logger.exception(f"Failed to load model: {e}")
        logger.info("Attempting fallback to GPT-2...")
        model_name = "gpt2"
        try:
            tokenizer, model = _load_model_and_tokenizer(model_name)
            logger.info(f"Fallback model loaded successfully: {model_name}")
        except Exception as fallback_error:
            logger.exception(f"Fallback model also failed: {fallback_error}")
            raise
    try:
        yield
    finally:
        # Cleanup: drop references so the weights can be freed, and return
        # cached GPU memory.
        logger.info("Shutting down...")
        model = None
        tokenizer = None
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
# Application object. The `lifespan` hook handles model loading at startup
# and cleanup at shutdown.
app = FastAPI(
    lifespan=lifespan,
    version="1.0.0",
    title="Custom Chat Model API",
    description="API for fine-tuned chat model",
)
def format_messages(messages: List[ChatMessage]) -> str:
    """Render a conversation as a single prompt string for generation.

    If the loaded tokenizer defines a chat template, that template is applied
    with a generation prompt appended. This gives the model the conversation
    in the format the template describes, rather than an ad-hoc transcript.

    Otherwise a plain transcript is built: one ``"System: ..."`` /
    ``"User: ..."`` / ``"Assistant: ..."`` line per message, followed by a
    trailing ``"Assistant:"`` cue. In this plain format, messages with any
    other role are skipped and a warning is logged.
    """
    if tokenizer is not None and getattr(tokenizer, "chat_template", None):
        conversation = [{"role": m.role, "content": m.content} for m in messages]
        return tokenizer.apply_chat_template(
            conversation, tokenize=False, add_generation_prompt=True
        )

    prefixes = {"system": "System", "user": "User", "assistant": "Assistant"}
    parts = []
    for message in messages:
        prefix = prefixes.get(message.role)
        if prefix is None:
            logger.warning("Skipping message with unknown role %r", message.role)
            continue
        parts.append(f"{prefix}: {message.content}\n")
    # Trailing cue so the model continues with the assistant's next turn.
    parts.append("Assistant:")
    return "".join(parts)
def _truncate_at_stop(text: str, stop: Optional[List[str]]) -> str:
    """Return *text* cut just before the earliest match of any stop sequence.

    Empty stop strings are ignored, since they would match at position 0 and
    make ``str.split`` raise. If no stop sequence occurs, *text* is returned
    unchanged.
    """
    if not stop:
        return text
    hits = [pos for pos in (text.find(seq) for seq in stop if seq) if pos != -1]
    return text[: min(hits)] if hits else text


def generate_response(
    prompt: str,
    max_tokens: int = 512,
    temperature: float = 0.7,
    top_p: float = 0.9,
    stop: Optional[List[str]] = None,
    max_input_tokens: int = 2048,
) -> tuple[str, Dict[str, int]]:
    """Generate a completion for *prompt* with the globally loaded model.

    Args:
        prompt: Fully formatted prompt text.
        max_tokens: Maximum number of new tokens to generate.
        temperature: Sampling temperature. A value <= 0 selects greedy
            decoding (sampling is disabled). ``None`` keeps sampling enabled
            and is passed through unchanged.
        top_p: Nucleus-sampling threshold; only passed when sampling.
        stop: Optional stop sequences. The text is cut before the earliest
            occurrence of any of them.
        max_input_tokens: Prompt budget in tokens (positive). Longer prompts
            keep only their LAST ``max_input_tokens`` tokens.

    Returns:
        ``(text, usage)``. ``text`` is the stripped completion; ``usage`` has
        ``input_tokens``, ``output_tokens`` and ``total_tokens``.
    """
    # Put the inputs on the same device as the model weights.
    device = next(model.parameters()).device

    encoded = tokenizer(prompt, return_tensors="pt")
    input_ids = encoded["input_ids"]
    attention_mask = encoded["attention_mask"]
    # Truncate from the LEFT. The end of the prompt holds the newest turn and
    # the assistant cue, so it must survive; right-side truncation would cut
    # exactly those parts off.
    if input_ids.shape[1] > max_input_tokens:
        input_ids = input_ids[:, -max_input_tokens:]
        attention_mask = attention_mask[:, -max_input_tokens:]
    input_ids = input_ids.to(device)
    attention_mask = attention_mask.to(device)
    input_length = input_ids.shape[1]

    # A non-positive temperature is not a valid sampling temperature, so
    # treat it as a request for greedy decoding.
    do_sample = temperature is None or temperature > 0
    generate_kwargs = dict(
        input_ids=input_ids,
        attention_mask=attention_mask,
        max_new_tokens=max_tokens,
        do_sample=do_sample,
        pad_token_id=tokenizer.pad_token_id,
        eos_token_id=tokenizer.eos_token_id,
        repetition_penalty=1.1,
    )
    if do_sample:
        generate_kwargs.update(temperature=temperature, top_p=top_p)

    with torch.no_grad():
        outputs = model.generate(**generate_kwargs)

    # Decode only the newly generated tokens, not the echoed prompt.
    generated_ids = outputs[0][input_length:]
    response = tokenizer.decode(generated_ids, skip_special_tokens=True)
    response = _truncate_at_stop(response, stop)

    # add_special_tokens=False: any BOS/EOS a tokenizer would prepend are not
    # part of the generated output and must not be counted.
    output_tokens = len(tokenizer.encode(response, add_special_tokens=False))
    usage = {
        "input_tokens": input_length,
        "output_tokens": output_tokens,
        "total_tokens": input_length + output_tokens,
    }
    return response.strip(), usage
@app.get("/")
async def root():
    """Landing endpoint: service name plus a static running status."""
    return dict(message="Custom Chat Model API", status="running")
@app.get("/health")
async def health_check():
    """Health probe: always reports "healthy", plus whether the model is set."""
    is_loaded = model is not None
    payload = {"status": "healthy"}
    payload["model_loaded"] = is_loaded
    return payload
@app.post("/chat/completions", response_model=ChatResponse)
async def chat_completions(request: ChatRequest):
    """Generate one complete (non-streamed) reply to the conversation.

    Raises:
        HTTPException: 503 if the model is not loaded yet; 500 if prompt
            formatting or generation fails (the detail carries the error text).
    """
    if model is None or tokenizer is None:
        raise HTTPException(status_code=503, detail="Model not loaded")
    try:
        prompt = format_messages(request.messages)
        # generate_response is blocking (tokenization + model.generate).
        # Calling it directly in this coroutine would stall the event loop
        # and freeze every other request, /health included, until generation
        # finished.
        # NOTE(review): requests can now generate concurrently on the shared
        # model; add a lock if device memory cannot accommodate that.
        response_content, usage = await run_in_threadpool(
            generate_response,
            prompt=prompt,
            max_tokens=request.max_tokens,
            temperature=request.temperature,
            top_p=request.top_p,
            stop=request.stop,
        )
        return ChatResponse(
            content=response_content,
            finish_reason="stop",
            usage=usage,
        )
    except Exception as e:
        # logger.exception keeps the traceback (logger.error dropped it).
        logger.exception(f"Error in chat completion: {e}")
        raise HTTPException(status_code=500, detail=str(e)) from e
@app.post("/chat/stream")
async def chat_stream(request: ChatRequest):
    """Stream a completion to the client as Server-Sent Events.

    The full response is generated first and then replayed in word-sized
    chunks (simulated streaming). Each event is ``data: <ChatStreamChunk
    JSON>``. A final chunk with empty content carries ``finish_reason="stop"``
    and ``usage``, and the stream ends with ``data: [DONE]``.

    The response status cannot change once the body is streaming. A failure
    during generation is therefore reported in-band, as a chunk with
    ``finish_reason="error"`` followed by ``[DONE]``.

    Raises:
        HTTPException: 503 if the model is not loaded; 500 if the streaming
            response cannot be set up.
    """
    if model is None or tokenizer is None:
        raise HTTPException(status_code=503, detail="Model not loaded")
    try:
        from fastapi.responses import StreamingResponse
        import json
        import re

        def generate_stream():
            try:
                prompt = format_messages(request.messages)
                # NOTE: streaming is simulated by chunking a finished
                # response; true token streaming would need model.generate
                # with a streamer.
                response_content, usage = generate_response(
                    prompt=prompt,
                    max_tokens=request.max_tokens,
                    temperature=request.temperature,
                    top_p=request.top_p,
                    stop=request.stop,
                )
            except Exception as e:
                # This runs lazily while the body is sent, so the outer
                # try/except cannot see it; report the error in the stream.
                logger.exception(f"Error in streaming: {e}")
                error_chunk = ChatStreamChunk(content="", finish_reason="error")
                yield f"data: {json.dumps(error_chunk.dict())}\n\n"
                yield "data: [DONE]\n\n"
                return

            # Each piece is a word plus the whitespace that follows it, so the
            # concatenated chunks reproduce the response exactly, including
            # newlines and indentation (str.split() would collapse them).
            for piece in re.findall(r"\S+\s*", response_content):
                chunk = ChatStreamChunk(content=piece, finish_reason=None)
                yield f"data: {json.dumps(chunk.dict())}\n\n"

            # Final chunk with usage info
            final_chunk = ChatStreamChunk(
                content="",
                finish_reason="stop",
                usage=usage,
            )
            yield f"data: {json.dumps(final_chunk.dict())}\n\n"
            yield "data: [DONE]\n\n"

        return StreamingResponse(
            generate_stream(),
            # "data: ...\n\n" framing is Server-Sent Events.
            media_type="text/event-stream",
            headers={"Cache-Control": "no-cache"},
        )
    except Exception as e:
        logger.exception(f"Error in streaming: {e}")
        raise HTTPException(status_code=500, detail=str(e)) from e
if __name__ == "__main__":
    # Run the API directly with uvicorn, listening on all interfaces.
    uvicorn.run(app, host="0.0.0.0", port=7860)