"""
AI Chat Application - Pure FastAPI Backend
Serves a custom frontend with an OpenAI-compatible API.
"""

import os
import json
import logging
import time
from typing import Optional, Dict, Any, List
from threading import Thread

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer
from fastapi import FastAPI, HTTPException, Response
from fastapi.responses import StreamingResponse, FileResponse
from fastapi.staticfiles import StaticFiles
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


class ChatMessage(BaseModel):
    role: str
    content: str


class ChatRequest(BaseModel):
    messages: List[ChatMessage]
    model: Optional[str] = "qwen-coder-3-30b"
    temperature: Optional[float] = 0.7
    max_tokens: Optional[int] = 2048
    stream: Optional[bool] = False


class ChatResponse(BaseModel):
    id: str
    object: str = "chat.completion"
    created: int
    model: str
    choices: List[Dict[str, Any]]


# Global model state (a single model is kept in memory at a time)
tokenizer = None
model = None
current_model_name = None
available_models = {
    "qwen-coder-3-30b": "Qwen/Qwen3-Coder-30B-A3B-Instruct",
    "qwen-4b-thinking": "Qwen/Qwen3-4B-Thinking-2507",
}


def load_model(model_id: str = "qwen-coder-3-30b"):
    """Load the specified Qwen model and tokenizer."""
    global tokenizer, model, current_model_name

    try:
        if model_id not in available_models:
            raise ValueError(f"Unknown model ID: {model_id}")

        model_name = available_models[model_id]

        # Skip reloading if the requested model is already in memory
        if current_model_name == model_name:
            logger.info(f"Model {model_name} is already loaded")
            return

        # Free the previously loaded model before loading a new one
        if model is not None:
            del model
            model = None
            if torch.cuda.is_available():
                torch.cuda.empty_cache()

        logger.info(f"Loading model: {model_name}")
        tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)

        if model_id == "qwen-4b-thinking":
            model = AutoModelForCausalLM.from_pretrained(
                model_name,
                torch_dtype=torch.float16,
                device_map="auto",
                trust_remote_code=True,
                low_cpu_mem_usage=True,
            )
        else:
            model = AutoModelForCausalLM.from_pretrained(
                model_name,
                torch_dtype=torch.float16,
                device_map="auto",
                trust_remote_code=True,
            )

        current_model_name = model_name
        logger.info(f"Model {model_name} loaded successfully")

    except Exception as e:
        logger.error(f"Error loading model {model_id}: {e}")
        logger.warning("Using fallback model response")


def generate_response(messages: List[ChatMessage], temperature: float = 0.7, max_tokens: int = 2048, model_id: str = "qwen-coder-3-30b"):
    """Generate a complete (non-streaming) response from the model."""
    try:
        # Load the requested model if it is not the one currently in memory
        if model is None or current_model_name != available_models.get(model_id):
            load_model(model_id)

        if model is None or tokenizer is None:
            # Fallback while the model is still loading (or failed to load)
            return f"I'm a Qwen AI assistant ({model_id}). The model is currently loading, please try again in a moment."

        formatted_messages = [{"role": msg.role, "content": msg.content} for msg in messages]

        # Build the prompt using the model's chat template
        text = tokenizer.apply_chat_template(
            formatted_messages,
            tokenize=False,
            add_generation_prompt=True
        )

        inputs = tokenizer(text, return_tensors="pt").to(model.device)

        # The thinking model gets tighter limits to keep latency reasonable
        if model_id == "qwen-4b-thinking":
            max_tokens = min(max_tokens, 1024)
            temperature = min(temperature, 0.8)

        with torch.no_grad():
            outputs = model.generate(
                **inputs,
                max_new_tokens=max_tokens,
                temperature=temperature,
                do_sample=True,
                pad_token_id=tokenizer.eos_token_id
            )

        # Decode only the newly generated tokens (skip the prompt)
        response = tokenizer.decode(outputs[0][inputs.input_ids.shape[1]:], skip_special_tokens=True)
        return response.strip()

    except Exception as e:
        logger.error(f"Error generating response: {e}")
        return f"I apologize, but I encountered an error while processing your request: {str(e)}"


def generate_streaming_response(messages: List[ChatMessage], temperature: float = 0.7, max_tokens: int = 2048, model_id: str = "qwen-coder-3-30b"):
    """Generate a streaming response from the model (SSE format)."""
    try:
        # Load the requested model if it is not the one currently in memory
        if model is None or current_model_name != available_models.get(model_id):
            load_model(model_id)

        if model is None or tokenizer is None:
            # Fallback while the model is still loading (or failed to load)
            response = f"I'm a Qwen AI assistant ({model_id}). The model is currently loading, please try again in a moment."
            for char in response:
                yield f"data: {json.dumps({'choices': [{'delta': {'content': char}}]})}\n\n"
                time.sleep(0.05)
            yield f"data: {json.dumps({'choices': [{'finish_reason': 'stop'}]})}\n\n"
            yield "data: [DONE]\n\n"
            return

        formatted_messages = [{"role": msg.role, "content": msg.content} for msg in messages]

        # Build the prompt using the model's chat template
        text = tokenizer.apply_chat_template(
            formatted_messages,
            tokenize=False,
            add_generation_prompt=True
        )

        inputs = tokenizer(text, return_tensors="pt").to(model.device)

        # The thinking model gets tighter limits to keep latency reasonable
        if model_id == "qwen-4b-thinking":
            max_tokens = min(max_tokens, 1024)
            temperature = min(temperature, 0.8)

        # Stream tokens from a background generation thread as they are produced
        streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)

        generation_kwargs = {
            **inputs,
            "max_new_tokens": max_tokens,
            "temperature": temperature,
            "do_sample": True,
            "pad_token_id": tokenizer.eos_token_id,
            "streamer": streamer
        }

        thread = Thread(target=model.generate, kwargs=generation_kwargs)
        thread.start()

        for new_text in streamer:
            if new_text:
                yield f"data: {json.dumps({'choices': [{'delta': {'content': new_text}}]})}\n\n"

        yield f"data: {json.dumps({'choices': [{'finish_reason': 'stop'}]})}\n\n"
        yield "data: [DONE]\n\n"

    except Exception as e:
        logger.error(f"Error in streaming generation: {e}")
        error_msg = f"Error: {str(e)}"
        yield f"data: {json.dumps({'choices': [{'delta': {'content': error_msg}}]})}\n\n"
        yield f"data: {json.dumps({'choices': [{'finish_reason': 'stop'}]})}\n\n"
        yield "data: [DONE]\n\n"

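# For reference, the SSE generator above emits chunks of the following shape
# (illustrative output, not captured from a real run):
#   data: {"choices": [{"delta": {"content": "Hello"}}]}
#   data: {"choices": [{"finish_reason": "stop"}]}
#   data: [DONE]
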

def generate_plain_text_stream(messages: List[ChatMessage], temperature: float = 0.7, max_tokens: int = 2048, model_id: str = "qwen-coder-3-30b"):
    """Plain text streaming generator used by the /chat compatibility endpoint (no SSE)."""
    try:
        # Load the requested model if it is not the one currently in memory
        if model is None or current_model_name != available_models.get(model_id):
            load_model(model_id)

        if model is None or tokenizer is None:
            # Fallback while the model is still loading (or failed to load)
            response = f"I'm a Qwen AI assistant ({model_id}). The model is currently loading, please try again in a moment."
            for ch in response:
                yield ch
                time.sleep(0.02)
            return

        formatted_messages = [{"role": m.role, "content": m.content} for m in messages]

        # Build the prompt using the model's chat template
        text = tokenizer.apply_chat_template(
            formatted_messages,
            tokenize=False,
            add_generation_prompt=True
        )

        inputs = tokenizer(text, return_tensors="pt").to(model.device)

        # The thinking model gets tighter limits to keep latency reasonable
        if model_id == "qwen-4b-thinking":
            max_tokens = min(max_tokens, 1024)
            temperature = min(temperature, 0.8)

        # Stream tokens from a background generation thread as they are produced
        streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
        generation_kwargs = {
            **inputs,
            "max_new_tokens": max_tokens,
            "temperature": temperature,
            "do_sample": True,
            "pad_token_id": tokenizer.eos_token_id,
            "streamer": streamer
        }

        thread = Thread(target=model.generate, kwargs=generation_kwargs)
        thread.start()

        for new_text in streamer:
            if new_text:
                yield new_text

    except Exception as e:
        logger.error(f"Error in plain streaming generation: {e}")
        yield f"[error] {str(e)}"

app = FastAPI(title="AI Chat API", description="OpenAI-compatible interface for Qwen models")

app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
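# NOTE: wildcard CORS (allow_origins=["*"]) keeps local development simple;
# for a public deployment you would likely restrict this to known origins.
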

@app.get("/")
async def serve_index():
    """Serve the main HTML file"""
    return FileResponse("public/index.html")


@app.get("/health")
async def health_check():
    """Health check endpoint"""
    return {"status": "healthy", "model_loaded": model is not None}


@app.get("/ping")
async def ping():
    """Simple ping endpoint"""
    return {"status": "pong"}


@app.head("/ping")
async def ping_head():
    """HEAD ping for health checks"""
    return Response(status_code=200)

| @app.get("/api/models") |
| async def list_models(): |
| """List available models""" |
| return { |
| "data": [ |
| { |
| "id": "qwen-coder-3-30b", |
| "object": "model", |
| "created": int(time.time()), |
| "owned_by": "qwen", |
| "name": "Qwen 3 Coder 30B", |
| "description": "Výkonný model pro programování" |
| }, |
| { |
| "id": "qwen-4b-thinking", |
| "object": "model", |
| "created": int(time.time()), |
| "owned_by": "qwen", |
| "name": "Qwen 4B Thinking", |
| "description": "Rychlejší odlehčený model" |
| } |
| ] |
| } |
|
|
| @app.post("/api/chat") |
| async def chat_completion(request: ChatRequest): |
| """OpenAI compatible chat completion endpoint""" |
| try: |
| model_id = request.model or "qwen-coder-3-30b" |
| |
| |
| if model_id not in available_models: |
| raise HTTPException(status_code=400, detail=f"Unknown model: {model_id}") |
| |
| if request.stream: |
| return StreamingResponse( |
| generate_streaming_response( |
| request.messages, |
| request.temperature or 0.7, |
| request.max_tokens or 2048, |
| model_id |
| ), |
| media_type="text/plain" |
| ) |
| else: |
| response_content = generate_response( |
| request.messages, |
| request.temperature or 0.7, |
| request.max_tokens or 2048, |
| model_id |
| ) |
| |
| return ChatResponse( |
| id=f"chatcmpl-{int(time.time())}", |
| created=int(time.time()), |
| model=model_id, |
| choices=[{ |
| "index": 0, |
| "message": { |
| "role": "assistant", |
| "content": response_content |
| }, |
| "finish_reason": "stop" |
| }] |
| ) |
| |
| except Exception as e: |
| logger.error(f"Error in chat completion: {e}") |
| raise HTTPException(status_code=500, detail=str(e)) |
|
|
|
|
| @app.post("/v1/chat/completions") |
| async def openai_chat_completion(request: ChatRequest): |
| """OpenAI API compatible endpoint""" |
| return await chat_completion(request) |
|
|
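# Example request against the OpenAI-compatible endpoint (illustrative; host and
# port depend on your deployment, 7860 is the default used below):
#   curl -N http://localhost:7860/v1/chat/completions \
#     -H "Content-Type: application/json" \
#     -d '{"model": "qwen-coder-3-30b", "stream": true,
#          "messages": [{"role": "user", "content": "Hello"}]}'
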
| @app.post("/chat") |
| async def chat_stream_compat(payload: Dict[str, Any]): |
| """Compatibility endpoint for frontend streaming /chat (plain text stream).""" |
| try: |
| message = str(payload.get("message", "") or "").strip() |
| history_raw = payload.get("history", []) or [] |
| model_id = payload.get("model", "qwen-coder-3-30b") |
| |
| |
| if model_id not in available_models: |
| model_id = "qwen-coder-3-30b" |
|
|
| history_msgs: List[ChatMessage] = [] |
| for item in history_raw: |
| role = item.get("role") |
| content = item.get("content") |
| if role and content is not None: |
| history_msgs.append(ChatMessage(role=role, content=str(content))) |
|
|
| if message: |
| history_msgs.append(ChatMessage(role="user", content=message)) |
|
|
| return StreamingResponse( |
| generate_plain_text_stream( |
| history_msgs, |
| temperature=0.7, |
| max_tokens=2048, |
| model_id=model_id |
| ), |
| media_type="text/plain; charset=utf-8" |
| ) |
| except Exception as e: |
| logger.error(f"/chat compatibility error: {e}") |
| raise HTTPException(status_code=400, detail="Invalid request body") |
|
|
| |
| app.mount("/", StaticFiles(directory="public", html=True), name="static") |
|
|
| |
| |

@app.on_event("startup")
async def startup_event():
    """Initialize the default model on startup"""
    # Load the default model in a daemon thread so the server starts serving
    # immediately; early chat requests receive the fallback "loading" message
    # until the model is ready.
    thread = Thread(target=load_model, args=("qwen-coder-3-30b",))
    thread.daemon = True
    thread.start()


if __name__ == "__main__":
    import uvicorn

    port = int(os.environ.get("PORT", 7860))

    uvicorn.run(
        app,
        host="0.0.0.0",
        port=port,
        access_log=True
    )
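
# To try the server locally (assuming this module is saved as app.py):
#   PORT=7860 python app.py
# The /chat compatibility endpoint accepts a JSON body such as
#   {"message": "Hi", "history": [], "model": "qwen-coder-3-30b"}
# and streams the reply back as plain text.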