# Spaces: Sleeping
| from fastapi import FastAPI, HTTPException | |
| from pydantic import BaseModel | |
| from typing import List, Optional, Dict, Any | |
| import torch | |
| from transformers import AutoTokenizer, AutoModelForCausalLM | |
| import uvicorn | |
| import logging | |
| from contextlib import asynccontextmanager | |
# Logging: configure the root logger at INFO and use a module-scoped logger.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Module-level singletons populated by the lifespan startup hook;
# both stay None until loading succeeds.
model = tokenizer = None
# Request/Response models
class ChatMessage(BaseModel):
    """One turn of a chat conversation."""

    # Expected to be "system", "user" or "assistant"; format_messages silently
    # skips messages with any other role (the field itself accepts any string).
    role: str  # "system", "user", "assistant"
    # Raw text of the turn, inserted verbatim into the prompt.
    content: str
class ChatRequest(BaseModel):
    """Request body shared by the completion and streaming chat endpoints."""

    # Conversation history, rendered into a single prompt in the given order.
    messages: List[ChatMessage]
    # Upper bound on newly generated tokens (forwarded as max_new_tokens).
    # NOTE(review): Optional permits an explicit null, which is forwarded
    # unchanged to generation — confirm downstream handles None.
    max_tokens: Optional[int] = 512
    # Sampling temperature forwarded to model.generate.
    temperature: Optional[float] = 0.7
    # Nucleus-sampling probability mass forwarded to model.generate.
    top_p: Optional[float] = 0.9
    # Strings at which the generated text is cut off after decoding.
    stop: Optional[List[str]] = None
class ChatResponse(BaseModel):
    """Non-streaming chat completion result."""

    # Generated assistant text (already cut at stop strings and stripped).
    content: str
    # Reason generation ended; the completion endpoint always reports "stop".
    finish_reason: str
    # Token counts with keys input_tokens, output_tokens, total_tokens.
    usage: Dict[str, int]
class ChatStreamChunk(BaseModel):
    """One server-sent event payload emitted by the streaming endpoint."""

    # Text fragment of the reply; empty in the final chunk.
    content: str
    # None for intermediate chunks; set (to "stop") only on the final chunk.
    finish_reason: Optional[str] = None
    # Token usage, attached only to the final chunk.
    usage: Optional[Dict[str, int]] = None
def _load_model_and_tokenizer(model_name: str, *, trust_remote_code: bool = False, use_fast: bool = True):
    """Load and return ``(model, tokenizer)`` for ``model_name``, padding-ready."""
    loaded_tokenizer = AutoTokenizer.from_pretrained(
        model_name,
        trust_remote_code=trust_remote_code,
        use_fast=use_fast,
    )
    use_cuda = torch.cuda.is_available()
    loaded_model = AutoModelForCausalLM.from_pretrained(
        model_name,
        # Half precision only on GPU; CPU kernels generally expect float32.
        torch_dtype=torch.float16 if use_cuda else torch.float32,
        device_map="auto" if use_cuda else None,
        trust_remote_code=trust_remote_code,
    )
    # generate() is called with pad_token_id; models such as GPT-2 ship without
    # a pad token, so reuse EOS for padding.
    if loaded_tokenizer.pad_token is None:
        loaded_tokenizer.pad_token = loaded_tokenizer.eos_token
    return loaded_model, loaded_tokenizer


@asynccontextmanager
async def lifespan(app: FastAPI):
    """Load the chat model at startup and release it at shutdown.

    Tries the primary model first; on any failure logs the error and falls
    back to GPT-2. The global ``model``/``tokenizer`` are only assigned once
    both have loaded successfully.

    Raises:
        Exception: whatever the fallback load raised, if both loads fail
            (the application then fails to start).
    """
    global model, tokenizer
    logger.info("Loading model and tokenizer...")
    # Smaller Qwen alternatives if memory is tight:
    # "Qwen/Qwen2-0.5B-Instruct", "Qwen/Qwen1.5-0.5B-Chat".
    primary_model = "Qwen/Qwen2.5-7B-Instruct"
    fallback_model = "gpt2"
    try:
        # trust_remote_code + slow tokenizer kept for Qwen checkpoint compatibility.
        model, tokenizer = _load_model_and_tokenizer(
            primary_model, trust_remote_code=True, use_fast=False
        )
        logger.info(f"Model loaded successfully: {primary_model}")
    except Exception as e:
        logger.exception(f"Failed to load model: {e}")
        logger.info("Attempting fallback to GPT-2...")
        try:
            model, tokenizer = _load_model_and_tokenizer(fallback_model)
            logger.info(f"Fallback model loaded successfully: {fallback_model}")
        except Exception as fallback_error:
            logger.exception(f"Fallback model also failed: {fallback_error}")
            raise
    try:
        yield
    finally:
        # Cleanup: drop the references so the weights can be freed.
        logger.info("Shutting down...")
        model = None
        tokenizer = None
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
# The ASGI application. Its lifespan hook loads the model and tokenizer
# before requests are served.
app = FastAPI(
    lifespan=lifespan,
    title="Custom Chat Model API",
    version="1.0.0",
    description="API for fine-tuned chat model",
)
def format_messages(messages: List[ChatMessage]) -> str:
    """Render chat messages as a plain-text transcript prompt.

    Each message becomes a ``"<Label>: <content>\\n"`` line, where the label is
    ``System``, ``User`` or ``Assistant`` for the matching ``role``. Messages
    with any other role are skipped. A trailing ``"Assistant:"`` cue is always
    appended so the model continues speaking as the assistant.

    Args:
        messages: Conversation turns, rendered in the order given.

    Returns:
        The prompt string, e.g. ``"User: hi\\nAssistant:"`` for one user turn.
    """
    labels = {"system": "System", "user": "User", "assistant": "Assistant"}
    # Collect parts and join once instead of repeated string concatenation.
    parts = [
        f"{labels[message.role]}: {message.content}\n"
        for message in messages
        if message.role in labels
    ]
    parts.append("Assistant:")
    return "".join(parts)
def _truncate_at_stop(text: str, stop: Optional[List[str]]) -> str:
    """Cut ``text`` at the earliest occurrence of any non-empty string in ``stop``."""
    if not stop:
        return text
    hits = [pos for pos in (text.find(s) for s in stop if s) if pos != -1]
    return text[: min(hits)] if hits else text


def generate_response(
    prompt: str,
    max_tokens: Optional[int] = 512,
    temperature: Optional[float] = 0.7,
    top_p: Optional[float] = 0.9,
    stop: Optional[List[str]] = None,
    max_input_tokens: int = 2048,
) -> tuple[str, Dict[str, int]]:
    """Generate a reply for ``prompt`` with the globally loaded model.

    Args:
        prompt: Full prompt text.
        max_tokens: Maximum number of new tokens; ``None`` means 512.
        temperature: Sampling temperature. A value ``<= 0`` selects greedy
            decoding (transformers rejects non-positive temperatures when
            sampling); ``None`` leaves it to the model's generation defaults.
        top_p: Nucleus-sampling mass (only used when sampling).
        stop: Strings at which the decoded reply is cut; the earliest
            occurrence of any of them wins. Empty strings are ignored.
        max_input_tokens: If the prompt is longer than this many tokens, only
            its last ``max_input_tokens`` tokens are kept. A value ``<= 0``
            disables the limit.

    Returns:
        ``(text, usage)`` where ``text`` is the stripped reply and ``usage``
        has keys ``input_tokens``, ``output_tokens`` and ``total_tokens``.
        ``output_tokens`` re-encodes the returned text, so it is an estimate.
    """
    if max_tokens is None:
        max_tokens = 512
    # Put inputs on the same device as the model's first parameter.
    device = next(model.parameters()).device
    encoded = tokenizer(prompt, return_tensors="pt")
    input_ids = encoded["input_ids"]
    attention_mask = encoded["attention_mask"]
    # Keep the TAIL of an over-long prompt: the newest turns and the closing
    # generation cue live at the end, so right-truncation would drop them.
    if max_input_tokens > 0 and input_ids.shape[1] > max_input_tokens:
        input_ids = input_ids[:, -max_input_tokens:]
        attention_mask = attention_mask[:, -max_input_tokens:]
    input_ids = input_ids.to(device)
    attention_mask = attention_mask.to(device)
    input_length = input_ids.shape[1]

    gen_kwargs = {
        "input_ids": input_ids,
        "attention_mask": attention_mask,
        "max_new_tokens": max_tokens,
        "pad_token_id": tokenizer.pad_token_id,
        "eos_token_id": tokenizer.eos_token_id,
        "repetition_penalty": 1.1,
    }
    if temperature is not None and temperature <= 0:
        gen_kwargs["do_sample"] = False
    else:
        gen_kwargs.update(do_sample=True, temperature=temperature, top_p=top_p)

    with torch.no_grad():
        outputs = model.generate(**gen_kwargs)

    # Decode only the newly generated tokens, not the echoed prompt.
    generated_ids = outputs[0][input_length:]
    response = tokenizer.decode(generated_ids, skip_special_tokens=True)
    response = _truncate_at_stop(response, stop).strip()

    output_tokens = len(tokenizer.encode(response, add_special_tokens=False))
    usage = {
        "input_tokens": input_length,
        "output_tokens": output_tokens,
        "total_tokens": input_length + output_tokens,
    }
    return response, usage
async def root():
    """Identify the service and report that it is running."""
    # NOTE(review): no route decorator is visible for this handler here;
    # confirm it is registered on `app`.
    banner = {"message": "Custom Chat Model API", "status": "running"}
    return banner
async def health_check():
    """Report liveness plus whether the global model has been loaded."""
    # NOTE(review): no route decorator is visible for this handler here;
    # confirm it is registered on `app`.
    loaded = model is not None
    return {"status": "healthy", "model_loaded": loaded}
async def chat_completions(request: ChatRequest):
    """Main chat completion endpoint: return the whole reply in one response.

    Args:
        request: Conversation plus generation parameters.

    Returns:
        ChatResponse with the generated text, ``finish_reason="stop"`` and the
        usage counts reported by ``generate_response``.

    Raises:
        HTTPException: 503 if the model or tokenizer is not loaded;
            500 (with the original exception chained) if generation fails.
    """
    if model is None or tokenizer is None:
        raise HTTPException(status_code=503, detail="Model not loaded")
    try:
        prompt = format_messages(request.messages)
        # NOTE(review): generate_response is synchronous and compute-heavy, so
        # calling it here blocks the event loop until generation finishes.
        # Consider offloading it (e.g. asyncio.to_thread), paired with a lock
        # if the model/tokenizer must not be used concurrently.
        response_content, usage = generate_response(
            prompt=prompt,
            max_tokens=request.max_tokens,
            temperature=request.temperature,
            top_p=request.top_p,
            stop=request.stop,
        )
        return ChatResponse(
            content=response_content,
            finish_reason="stop",
            usage=usage,
        )
    except Exception as e:
        # logger.exception keeps the traceback that logger.error dropped.
        logger.exception(f"Error in chat completion: {e}")
        raise HTTPException(status_code=500, detail=str(e)) from e
async def chat_stream(request: ChatRequest):
    """Streaming chat completion endpoint using Server-Sent Events.

    The full reply is generated first and then replayed as a series of
    ``data: <ChatStreamChunk JSON>`` events, one word (plus its following
    whitespace) per event, followed by a final chunk carrying
    ``finish_reason="stop"`` and usage, and a ``data: [DONE]`` sentinel.

    Raises:
        HTTPException: 503 if the model or tokenizer is not loaded;
            500 if building the streaming response fails.
    """
    if model is None or tokenizer is None:
        raise HTTPException(status_code=503, detail="Model not loaded")
    try:
        import json
        import re

        from fastapi.responses import StreamingResponse

        def generate_stream():
            # This generator runs lazily while the response is being streamed,
            # i.e. after chat_stream has returned, so the outer try/except
            # cannot catch its errors. Log them here and abort the stream.
            try:
                prompt = format_messages(request.messages)
                # Streaming is simulated: generate the whole reply, then chunk it.
                response_content, usage = generate_response(
                    prompt=prompt,
                    max_tokens=request.max_tokens,
                    temperature=request.temperature,
                    top_p=request.top_p,
                    stop=request.stop,
                )
            except Exception as e:
                logger.exception(f"Error in streaming: {e}")
                raise
            # Each piece is a word plus the whitespace after it, so the pieces
            # concatenate back to response_content exactly (newlines included).
            for piece in re.findall(r"\S+\s*", response_content):
                chunk = ChatStreamChunk(content=piece, finish_reason=None)
                yield f"data: {json.dumps(chunk.dict())}\n\n"
            # Final chunk with usage info
            final_chunk = ChatStreamChunk(
                content="",
                finish_reason="stop",
                usage=usage,
            )
            yield f"data: {json.dumps(final_chunk.dict())}\n\n"
            yield "data: [DONE]\n\n"

        return StreamingResponse(
            generate_stream(),
            # The body uses SSE framing, so advertise the SSE media type.
            media_type="text/event-stream",
            headers={"Cache-Control": "no-cache"},
        )
    except Exception as e:
        logger.exception(f"Error in streaming: {e}")
        raise HTTPException(status_code=500, detail=str(e)) from e
if __name__ == "__main__":
    import os

    # Defaults are the previous hard-coded values; HOST / PORT environment
    # variables allow overriding them without editing the file.
    uvicorn.run(
        app,
        host=os.environ.get("HOST", "0.0.0.0"),
        port=int(os.environ.get("PORT", "7860")),
    )