from fastapi import FastAPI, HTTPException
from fastapi.responses import StreamingResponse
from pydantic import BaseModel
from typing import List, Optional, Dict, Any
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
import uvicorn
import json
import logging
from contextlib import asynccontextmanager
# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# Global variables for model and tokenizer
model = None
tokenizer = None
# Request/Response models
class ChatMessage(BaseModel):
    role: str  # "system", "user", "assistant"
    content: str


class ChatRequest(BaseModel):
    messages: List[ChatMessage]
    max_tokens: Optional[int] = 512
    temperature: Optional[float] = 0.7
    top_p: Optional[float] = 0.9
    stop: Optional[List[str]] = None


class ChatResponse(BaseModel):
    content: str
    finish_reason: str
    usage: Dict[str, int]


class ChatStreamChunk(BaseModel):
    content: str
    finish_reason: Optional[str] = None
    usage: Optional[Dict[str, int]] = None
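
# Example request body for /chat/completions, matching ChatRequest above
# (illustrative values only):
# {
#   "messages": [
#     {"role": "system", "content": "You are a helpful assistant."},
#     {"role": "user", "content": "Hello!"}
#   ],
#   "max_tokens": 256,
#   "temperature": 0.7,
#   "top_p": 0.9
# }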
@asynccontextmanager
async def lifespan(app: FastAPI):
    # Load model and tokenizer on startup
    global model, tokenizer
    logger.info("Loading model and tokenizer...")

    # SOLUTION 1: Use a more compatible model
    # (replaces Qwen3-4B with a widely supported instruct model)
    # model_name = "microsoft/DialoGPT-medium"  # Alternatives: "gpt2", "microsoft/DialoGPT-small"
    model_name = "Qwen/Qwen2.5-7B-Instruct"

    # SOLUTION 2: If you want a smaller Qwen model, try one of these:
    # model_name = "Qwen/Qwen1.5-0.5B-Chat"    # Smaller, more compatible Qwen model
    # model_name = "Qwen/Qwen2-0.5B-Instruct"  # Even smaller option

    try:
        # SOLUTION 3: trust_remote_code=True and use_fast=False improve compatibility
        tokenizer = AutoTokenizer.from_pretrained(
            model_name,
            trust_remote_code=True,
            use_fast=False  # Slow tokenizer avoids some fast-tokenizer conversion issues
        )
        model = AutoModelForCausalLM.from_pretrained(
            model_name,
            torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
            device_map="auto" if torch.cuda.is_available() else None,
            trust_remote_code=True
        )
        # Set pad token if not present
        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token
        logger.info(f"Model loaded successfully: {model_name}")
    except Exception as e:
        logger.error(f"Failed to load model: {e}")
        # SOLUTION 4: Fall back to a small, guaranteed-available model
        logger.info("Attempting fallback to GPT-2...")
        try:
            model_name = "gpt2"
            tokenizer = AutoTokenizer.from_pretrained(model_name)
            model = AutoModelForCausalLM.from_pretrained(
                model_name,
                torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
                device_map="auto" if torch.cuda.is_available() else None
            )
            if tokenizer.pad_token is None:
                tokenizer.pad_token = tokenizer.eos_token
            logger.info(f"Fallback model loaded successfully: {model_name}")
        except Exception as fallback_error:
            logger.error(f"Fallback model also failed: {fallback_error}")
            raise fallback_error

    yield

    # Cleanup on shutdown
    logger.info("Shutting down...")
# Initialize FastAPI app
app = FastAPI(
    title="Custom Chat Model API",
    description="API for fine-tuned chat model",
    version="1.0.0",
    lifespan=lifespan
)
def format_messages(messages: List[ChatMessage]) -> str:
    """Format messages into a prompt string"""
    formatted_prompt = ""
    for message in messages:
        if message.role == "system":
            formatted_prompt += f"System: {message.content}\n"
        elif message.role == "user":
            formatted_prompt += f"User: {message.content}\n"
        elif message.role == "assistant":
            formatted_prompt += f"Assistant: {message.content}\n"
    # Add assistant prompt for completion
    formatted_prompt += "Assistant:"
    return formatted_prompt
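
# Hedged alternative: instruct-tuned models such as Qwen2.5 ship a chat template,
# and using it usually matches the model's training format better than the plain
# "User:/Assistant:" prompt above. A minimal sketch, assuming the loaded tokenizer
# defines a chat template (the helper name is ours, not part of the original app):
def format_messages_with_template(messages: List[ChatMessage]) -> str:
    """Format messages via the tokenizer's built-in chat template, if available."""
    if getattr(tokenizer, "chat_template", None):
        return tokenizer.apply_chat_template(
            [{"role": m.role, "content": m.content} for m in messages],
            tokenize=False,
            add_generation_prompt=True,  # append the assistant turn header
        )
    # Fall back to the simple formatting above
    return format_messages(messages)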
def generate_response(
    prompt: str,
    max_tokens: int = 512,
    temperature: float = 0.7,
    top_p: float = 0.9,
    stop: Optional[List[str]] = None
) -> tuple[str, Dict[str, int]]:
    """Generate response using the loaded model"""
    # Handle device placement more robustly
    device = next(model.parameters()).device
    inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=2048)
    input_ids = inputs["input_ids"].to(device)
    attention_mask = inputs["attention_mask"].to(device)
    input_length = input_ids.shape[1]

    # Generate response
    with torch.no_grad():
        outputs = model.generate(
            input_ids=input_ids,
            attention_mask=attention_mask,
            max_new_tokens=max_tokens,
            temperature=temperature,
            top_p=top_p,
            do_sample=True,
            pad_token_id=tokenizer.pad_token_id,
            eos_token_id=tokenizer.eos_token_id,
            repetition_penalty=1.1
        )

    # Decode only the generated part
    generated_ids = outputs[0][input_length:]
    response = tokenizer.decode(generated_ids, skip_special_tokens=True)

    # Truncate at the first stop sequence, if any
    if stop:
        for stop_token in stop:
            if stop_token in response:
                response = response.split(stop_token)[0]
                break

    # Calculate token usage
    output_tokens = len(tokenizer.encode(response))
    usage = {
        "input_tokens": input_length,
        "output_tokens": output_tokens,
        "total_tokens": input_length + output_tokens
    }
    return response.strip(), usage
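
# Note: generate() is called with do_sample=True, so a temperature of 0 in the
# request will raise a ValueError inside transformers. If deterministic output is
# needed, one option (not applied here) is to switch to greedy decoding, e.g.
# do_sample=temperature > 0.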
@app.get("/")
async def root():
    return {"message": "Custom Chat Model API", "status": "running"}


@app.get("/health")
async def health_check():
    return {"status": "healthy", "model_loaded": model is not None}
@app.post("/chat/completions", response_model=ChatResponse)
async def chat_completions(request: ChatRequest):
    """Main chat completion endpoint"""
    if model is None or tokenizer is None:
        raise HTTPException(status_code=503, detail="Model not loaded")
    try:
        # Format messages into prompt
        prompt = format_messages(request.messages)
        # Generate response
        response_content, usage = generate_response(
            prompt=prompt,
            max_tokens=request.max_tokens,
            temperature=request.temperature,
            top_p=request.top_p,
            stop=request.stop
        )
        return ChatResponse(
            content=response_content,
            finish_reason="stop",
            usage=usage
        )
    except Exception as e:
        logger.error(f"Error in chat completion: {e}")
        raise HTTPException(status_code=500, detail=str(e))
@app.post("/chat/stream")
async def chat_stream(request: ChatRequest):
    """Streaming chat completion endpoint"""
    if model is None or tokenizer is None:
        raise HTTPException(status_code=503, detail="Model not loaded")
    try:
        def generate_stream():
            prompt = format_messages(request.messages)
            # For simplicity, we simulate streaming by chunking an already-complete
            # response. A real implementation would stream tokens from model.generate
            # (see the TextIteratorStreamer sketch below).
            response_content, usage = generate_response(
                prompt=prompt,
                max_tokens=request.max_tokens,
                temperature=request.temperature,
                top_p=request.top_p,
                stop=request.stop
            )
            # Split response into word-sized chunks
            words = response_content.split()
            for i, word in enumerate(words):
                chunk = ChatStreamChunk(
                    content=word + " " if i < len(words) - 1 else word,
                    finish_reason=None
                )
                yield f"data: {json.dumps(chunk.dict())}\n\n"
            # Final chunk with usage info
            final_chunk = ChatStreamChunk(
                content="",
                finish_reason="stop",
                usage=usage
            )
            yield f"data: {json.dumps(final_chunk.dict())}\n\n"
            yield "data: [DONE]\n\n"

        # The generator emits server-sent-event lines, so use text/event-stream
        return StreamingResponse(
            generate_stream(),
            media_type="text/event-stream",
            headers={"Cache-Control": "no-cache"}
        )
    except Exception as e:
        logger.error(f"Error in streaming: {e}")
        raise HTTPException(status_code=500, detail=str(e))
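
# Hedged sketch: true token-by-token streaming. The /chat/stream endpoint above
# simulates streaming by chunking a completed response; a real streaming path
# could run model.generate in a worker thread with transformers'
# TextIteratorStreamer and yield pieces as they are produced. This helper
# (`stream_generate`) is illustrative only and not wired into any route.
def stream_generate(
    prompt: str,
    max_tokens: int = 512,
    temperature: float = 0.7,
    top_p: float = 0.9,
):
    """Yield generated text pieces as the model produces them (sketch)."""
    from threading import Thread
    from transformers import TextIteratorStreamer

    device = next(model.parameters()).device
    inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=2048).to(device)
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
    generation_kwargs = dict(
        **inputs,
        max_new_tokens=max_tokens,
        temperature=temperature,
        top_p=top_p,
        do_sample=True,
        pad_token_id=tokenizer.pad_token_id,
        streamer=streamer,
    )
    # generate() blocks until finished, so run it in a background thread
    # while this generator consumes text pieces from the streamer
    thread = Thread(target=model.generate, kwargs=generation_kwargs)
    thread.start()
    for text_piece in streamer:
        yield text_piece
    thread.join()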
if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=7860)
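
# Quick local test (illustrative only; assumes the server is running on port 7860):
#
#   import requests
#   resp = requests.post(
#       "http://localhost:7860/chat/completions",
#       json={"messages": [{"role": "user", "content": "Hi there"}]},
#   )
#   print(resp.json()["content"])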