Spaces:
Runtime error
Runtime error
Update Gradio app with multiple files
Browse files- model_server.py +285 -0
- requirements.txt +2 -0
- run_client.py +144 -0
- terminal_chatbot.py +297 -0
- updated_app.py +271 -0
- updated_models.py +213 -0
model_server.py
ADDED
|
@@ -0,0 +1,285 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""
|
| 3 |
+
AI Coding Model Server
|
| 4 |
+
FastAPI server that hosts the 5B parameter coding model
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
import torch
|
| 8 |
+
import spaces
|
| 9 |
+
import uvicorn
|
| 10 |
+
from fastapi import FastAPI, HTTPException, BackgroundTasks
|
| 11 |
+
from fastapi.middleware.cors import CORSMiddleware
|
| 12 |
+
from pydantic import BaseModel, Field
|
| 13 |
+
from typing import List, Dict, Any, Optional
|
| 14 |
+
import logging
|
| 15 |
+
import os
|
| 16 |
+
import asyncio
|
| 17 |
+
import threading
|
| 18 |
+
from contextlib import asynccontextmanager
|
| 19 |
+
|
| 20 |
+
# Import model components
|
| 21 |
+
from models import CodeModel
|
| 22 |
+
from utils import format_code_response, validate_code_syntax
|
| 23 |
+
|
| 24 |
+
# Configure logging
|
| 25 |
+
logging.basicConfig(level=logging.INFO)
|
| 26 |
+
logger = logging.getLogger(__name__)
|
| 27 |
+
|
| 28 |
+
# Global model instance
|
| 29 |
+
code_model = None
|
| 30 |
+
model_loading = False
|
| 31 |
+
|
| 32 |
+
class ChatMessage(BaseModel):
    """Request body for POST /api/chat.

    Carries the new user message plus the prior conversation so the
    server can remain stateless between calls.
    """
    message: str = Field(..., description="User's message")
    # Prior turns as [{"role": ..., "content": ...}] dicts; the handler
    # appends the new user/assistant turns and echoes the list back.
    history: List[Dict[str, str]] = Field(default_factory=list, description="Chat history")
    language: str = Field(default="python", description="Target programming language")
    # Pydantic enforces the 0.1-1.0 bound (ge/le) before the handler runs.
    temperature: float = Field(default=0.7, ge=0.1, le=1.0, description="Generation temperature")
|
| 38 |
+
|
| 39 |
+
class ChatResponse(BaseModel):
    """Response body for POST /api/chat.

    Mirrors the OpenAI-style ``choices[0].message.content`` shape that
    the terminal client reads.
    """
    choices: List[Dict[str, Dict[str, str]]] = Field(..., description="Generated responses")
    # Full conversation including the turns added by this request.
    history: List[Dict[str, str]] = Field(..., description="Updated chat history")
    # Not populated by the current /api/chat handler; reserved for future use.
    usage: Optional[Dict[str, int]] = Field(None, description="Token usage information")
|
| 44 |
+
|
| 45 |
+
class HealthResponse(BaseModel):
    """Response body for GET /health."""
    # "loading" while the model loads, "healthy" once it is ready.
    status: str
    model_loaded: bool
    model_name: str
    device: str
    # CUDA memory stats in GB (allocated/cached/total); None on CPU.
    memory_usage: Optional[Dict[str, Any]] = None
|
| 52 |
+
|
| 53 |
+
class ModelInfoResponse(BaseModel):
    """Response body for GET /model/info.

    Field names must match the keys returned by
    ``CodeModel.get_model_info()``, which is splatted directly into
    this model by the endpoint.
    """
    model_name: str
    parameter_count: str
    max_length: int
    device: str
    is_loaded: bool
    vocab_size: int
|
| 61 |
+
|
| 62 |
+
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Application lifespan management.

    FastAPI runs the code before ``yield`` at startup and the code
    after it at shutdown. The model is loaded here so the server can
    accept connections immediately and report "loading" via /health.
    """
    # Startup: kick off (and await) the one-time model load.
    logger.info("Starting up AI Coding Model Server...")
    await load_model()

    yield

    # Shutdown
    logger.info("Shutting down server...")
|
| 73 |
+
|
| 74 |
+
async def load_model():
    """Load the model into the module-global ``code_model``, once.

    Safe to call multiple times: returns immediately if the model is
    already loaded or a load is in progress. On failure, ``code_model``
    is reset to None so /health keeps reporting 503.
    """
    global code_model, model_loading

    # Idempotence guard: never start a second concurrent load.
    if code_model is not None or model_loading:
        return

    model_loading = True
    logger.info("Loading coding model...")

    try:
        # Run the (blocking) model construction in the default thread
        # pool so the event loop stays responsive.
        # NOTE: get_running_loop() replaces the deprecated
        # get_event_loop() inside coroutines.
        loop = asyncio.get_running_loop()
        code_model = await loop.run_in_executor(None, CodeModel)

        if code_model.is_loaded:
            logger.info(f"✅ Model loaded successfully: {code_model.model_name}")
        else:
            logger.error("❌ Failed to load model")

    except Exception as e:
        logger.error(f"❌ Error loading model: {e}")
        code_model = None
    finally:
        # Always clear the flag so a failed load can be retried.
        model_loading = False
|
| 99 |
+
|
| 100 |
+
def create_app() -> FastAPI:
    """Create and configure the FastAPI application.

    Endpoints are defined as closures so they can read the module-level
    ``code_model`` / ``model_loading`` state that ``lifespan`` manages.
    """

    # Create FastAPI app with lifespan management (loads the model on startup).
    app = FastAPI(
        title="AI Coding Model Server",
        description="FastAPI server hosting a 5B parameter coding model",
        version="1.0.0",
        lifespan=lifespan
    )

    # Add CORS middleware so browser clients (e.g. the Gradio app) can call us.
    app.add_middleware(
        CORSMiddleware,
        allow_origins=["*"],  # Configure appropriately for production
        allow_credentials=True,
        allow_methods=["*"],
        allow_headers=["*"],
    )

    @app.get("/", response_model=Dict[str, str])
    async def root():
        """Root endpoint."""
        return {
            "message": "AI Coding Model Server",
            "version": "1.0.0",
            "status": "running" if code_model and code_model.is_loaded else "loading"
        }

    @app.get("/health", response_model=HealthResponse)
    async def health_check():
        """Health check endpoint.

        Returns 200 with status="loading" while the model loads, 503 if
        loading finished without a usable model, and full stats otherwise.
        """
        if model_loading:
            return HealthResponse(
                status="loading",
                model_loaded=False,
                model_name="Loading...",
                device="unknown"
            )

        if not code_model or not code_model.is_loaded:
            raise HTTPException(status_code=503, detail="Model not loaded")

        # Get memory usage if available (CUDA only; None on CPU).
        memory_info = None
        if torch.cuda.is_available():
            memory_info = {
                "allocated": torch.cuda.memory_allocated() / 1024**3,  # GB
                "cached": torch.cuda.memory_reserved() / 1024**3,  # GB
                "total": torch.cuda.get_device_properties(0).total_memory / 1024**3
            }

        return HealthResponse(
            status="healthy",
            model_loaded=True,
            model_name=code_model.model_name,
            device=code_model.device,
            memory_usage=memory_info
        )

    @app.get("/model/info", response_model=ModelInfoResponse)
    async def model_info():
        """Get detailed model information."""
        if not code_model:
            raise HTTPException(status_code=503, detail="Model not loaded")

        # get_model_info() keys must match ModelInfoResponse fields.
        info = code_model.get_model_info()
        return ModelInfoResponse(**info)

    @app.post("/api/chat", response_model=ChatResponse)
    async def chat(request: ChatMessage):
        """Main chat endpoint: generate a reply and return updated history."""
        if model_loading:
            raise HTTPException(status_code=503, detail="Model is still loading")

        if not code_model or not code_model.is_loaded:
            raise HTTPException(status_code=503, detail="Model not loaded")

        try:
            # Generate response using the model; copy so the request's
            # history is never mutated in place.
            messages = request.history.copy()
            messages.append({"role": "user", "content": request.message})

            response_text = code_model.generate(
                messages=messages,
                temperature=request.temperature,
                max_new_tokens=2048,
                language=request.language
            )

            # Format the response (markdown/code-fence cleanup).
            formatted_response = format_code_response(response_text)

            # Update chat history with both new turns for the client to keep.
            new_history = request.history.copy()
            new_history.append({"role": "user", "content": request.message})
            new_history.append({"role": "assistant", "content": formatted_response})

            return ChatResponse(
                choices=[{"message": {"content": formatted_response}}],
                history=new_history
            )

        except Exception as e:
            logger.error(f"Chat error: {e}")
            raise HTTPException(status_code=500, detail=f"Generation error: {str(e)}")

    @app.post("/api/validate-code")
    async def validate_code(request: Dict[str, Any]):
        """Validate code syntax."""
        code = request.get("code", "")
        language = request.get("language", "python")

        if not code:
            raise HTTPException(status_code=400, detail="No code provided")

        validation_result = validate_code_syntax(code, language)
        return validation_result

    @app.get("/api/languages")
    async def get_supported_languages():
        """Get list of supported programming languages."""
        return {
            "languages": [
                "python", "javascript", "java", "cpp", "c", "go", "rust",
                "typescript", "php", "ruby", "swift", "kotlin", "sql",
                "html", "css", "bash", "powershell"
            ]
        }

    return app
|
| 231 |
+
|
| 232 |
+
def run_server(host: str = "0.0.0.0", port: int = 8000, reload: bool = False):
    """Run the FastAPI server.

    Args:
        host: Bind address for uvicorn.
        port: TCP port to listen on.
        reload: Enable uvicorn's auto-reload (development only).
    """
    console_info = f"""
🚀 AI Coding Model Server Starting...

📊 Server Info:
• Host: {host}
• Port: {port}
• Model: Loading...
• Device: {'CUDA' if torch.cuda.is_available() else 'CPU'}

🔗 Endpoints:
• Health: http://{host}:{port}/health
• Model Info: http://{host}:{port}/model/info
• Chat: http://{host}:{port}/api/chat
• API Docs: http://{host}:{port}/docs

💡 Usage:
• Terminal client: python terminal_chatbot.py
• API calls: POST to /api/chat with chat messages
• Check status: GET /health

⚡ Server is ready! Press Ctrl+C to stop.
"""

    print(console_info)

    # Run server.
    # BUG FIX: "model_server:create_app" points at an app *factory*, so
    # uvicorn needs factory=True — without it, uvicorn treats the factory
    # function itself as the ASGI app and crashes on the first request.
    # The import-string form (rather than a pre-built app instance) is
    # required for --reload to work. The previous unused local
    # `app = create_app()` has been dropped.
    uvicorn.run(
        "model_server:create_app",
        factory=True,
        host=host,
        port=port,
        reload=reload,
        log_level="info",
        access_log=True
    )
|
| 270 |
+
|
| 271 |
+
if __name__ == "__main__":
    import argparse

    # Command-line entry point: parse the server flags, then delegate to
    # run_server() with the resulting values.
    arg_parser = argparse.ArgumentParser(description="AI Coding Model Server")
    arg_parser.add_argument("--host", default="0.0.0.0", help="Server host")
    arg_parser.add_argument("--port", type=int, default=8000, help="Server port")
    arg_parser.add_argument("--reload", action="store_true", help="Auto-reload on changes")

    options = arg_parser.parse_args()
    run_server(host=options.host, port=options.port, reload=options.reload)
|
requirements.txt
CHANGED
|
@@ -19,3 +19,5 @@ matplotlib
|
|
| 19 |
seaborn
|
| 20 |
jupyter
|
| 21 |
ipywidgets
|
|
|
|
|
|
|
|
|
| 19 |
seaborn
|
| 20 |
jupyter
|
| 21 |
ipywidgets
|
| 22 |
+
rich
|
| 23 |
+
pydantic
|
run_client.py
ADDED
|
@@ -0,0 +1,144 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""
|
| 3 |
+
Easy launcher for the terminal chatbot
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
import subprocess
|
| 7 |
+
import sys
|
| 8 |
+
import os
|
| 9 |
+
|
| 10 |
+
def main():
    """Interactive launcher: verify required files, then start the server,
    the terminal chatbot, or both.

    Exits with status 1 on missing files or an invalid menu choice.
    """
    print("🤖 AI Coding Assistant Terminal Launcher")
    print("=" * 50)

    # BUG FIX: the original only checked model_server.py and
    # terminal_chatbot.py, yet printed "✅ models.py" / "✅ utils.py"
    # unconditionally. Check every required file before claiming success.
    required_files = ["model_server.py", "terminal_chatbot.py", "models.py", "utils.py"]
    missing = [name for name in required_files if not os.path.exists(name)]
    if missing:
        for name in missing:
            print(f"❌ {name} not found!")
        print("Make sure all files are in the same directory.")
        sys.exit(1)

    print("📋 Files found:")
    for name in required_files:
        print(f"   ✅ {name}")
    print()

    # Ask user what they want to run
    print("What would you like to run?")
    print("1. Start model server (required for chatbot)")
    print("2. Start terminal chatbot (requires running server)")
    print("3. Start both (server in background, then chatbot)")

    try:
        choice = input("\nEnter your choice (1-3): ").strip()
    except KeyboardInterrupt:
        print("\n👋 Goodbye!")
        sys.exit(0)

    if choice == "1":
        print("\n🚀 Starting model server...")
        print("💡 Server will run on http://localhost:8000")
        print("💡 Press Ctrl+C to stop")
        try:
            subprocess.run([sys.executable, "model_server.py"])
        except KeyboardInterrupt:
            print("\n🛑 Server stopped")

    elif choice == "2":
        print("\n🤖 Starting terminal chatbot...")
        print("💡 Make sure the server is running first!")
        print("💡 If you get connection errors, run option 1 first")
        try:
            subprocess.run([sys.executable, "terminal_chatbot.py"])
        except KeyboardInterrupt:
            print("\n👋 Chatbot stopped")

    elif choice == "3":
        print("\n🚀 Starting server in background...")
        print("💡 Server will run on http://localhost:8000")

        # Start server in background; we own this process and must reap it.
        server_process = subprocess.Popen([sys.executable, "model_server.py"])

        try:
            # Give the server a moment to bind its port before the client
            # tries to connect. (Crude; the chatbot also retries on its own.)
            print("⏳ Waiting for server to start...")
            import time
            time.sleep(5)

            print("🤖 Starting terminal chatbot...")
            subprocess.run([sys.executable, "terminal_chatbot.py"])

        except KeyboardInterrupt:
            print("\n🛑 Stopping...")
        finally:
            # Always clean up the background server process.
            print("🧹 Stopping server...")
            server_process.terminate()
            server_process.wait()
            print("✅ Server stopped")

    else:
        print("❌ Invalid choice. Please run again and select 1, 2, or 3.")
        sys.exit(1)

if __name__ == "__main__":
    main()
|
| 94 |
+
|
| 95 |
+
This creates a complete client-server architecture:
|
| 96 |
+
|
| 97 |
+
## 🚀 **Key Features:**
|
| 98 |
+
|
| 99 |
+
### **Terminal Chatbot (`terminal_chatbot.py`)**
|
| 100 |
+
- Beautiful CLI interface with Rich formatting
|
| 101 |
+
- Command support (`/help`, `/lang`, `/temp`, `/clear`, etc.)
|
| 102 |
+
- Real-time API communication
|
| 103 |
+
- Syntax-highlighted code display
|
| 104 |
+
- Conversation history management
|
| 105 |
+
|
| 106 |
+
### **Model Server (`model_server.py`)**
|
| 107 |
+
- FastAPI server hosting the 5B+ parameter model
|
| 108 |
+
- RESTful API endpoints for chat and model info
|
| 109 |
+
- Health monitoring and status checking
|
| 110 |
+
- CORS enabled for web clients
|
| 111 |
+
- Background model loading
|
| 112 |
+
|
| 113 |
+
### **Updated Gradio App (`updated_app.py`)**
|
| 114 |
+
- Works with the API server
|
| 115 |
+
- Real-time status monitoring
|
| 116 |
+
- Same features as before but via API
|
| 117 |
+
|
| 118 |
+
### **Easy Launcher (`run_client.py`)**
|
| 119 |
+
- Simple menu-driven interface
|
| 120 |
+
- Can start server, client, or both
|
| 121 |
+
- Error checking and guidance
|
| 122 |
+
|
| 123 |
+
## 📋 **How to Use:**
|
| 124 |
+
|
| 125 |
+
1. **Start the server:**
|
| 126 |
+
python model_server.py
|
| 127 |
+
|
| 128 |
+
2. **Start the terminal chatbot:**
|
| 129 |
+
python terminal_chatbot.py
|
| 130 |
+
|
| 131 |
+
3. **Or use the easy launcher:**
|
| 132 |
+
python run_client.py
|
| 133 |
+
|
| 134 |
+
4. **For the Gradio web interface:**
|
| 135 |
+
python updated_app.py
|
| 136 |
+
|
| 137 |
+
## 🔗 **API Endpoints:**
|
| 138 |
+
- `GET /health` - Check server status
|
| 139 |
+
- `GET /model/info` - Get model information
|
| 140 |
+
- `POST /api/chat` - Send chat messages
|
| 141 |
+
- `POST /api/validate-code` - Validate code syntax
|
| 142 |
+
- `GET /api/languages` - Get supported languages
|
| 143 |
+
|
| 144 |
+
The terminal chatbot provides a professional CLI experience with syntax highlighting, command support, and real-time API communication!
|
terminal_chatbot.py
ADDED
|
@@ -0,0 +1,297 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""
|
| 3 |
+
Terminal-based AI Coding Assistant
|
| 4 |
+
A command-line interface for the 5B parameter coding model via API
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
import requests
|
| 8 |
+
import json
|
| 9 |
+
import time
|
| 10 |
+
import sys
|
| 11 |
+
import os
|
| 12 |
+
from typing import List, Dict, Any
|
| 13 |
+
from rich.console import Console
|
| 14 |
+
from rich.panel import Panel
|
| 15 |
+
from rich.syntax import Syntax
|
| 16 |
+
from rich.prompt import Prompt, Confirm
|
| 17 |
+
from rich.markdown import Markdown
|
| 18 |
+
from rich.table import Table
|
| 19 |
+
from rich import print as rprint
|
| 20 |
+
|
| 21 |
+
console = Console()
|
| 22 |
+
|
| 23 |
+
class TerminalChatbot:
    """Terminal-based chatbot client for the AI coding assistant.

    Talks to the model server's REST API and renders replies with Rich
    (markdown for prose, syntax highlighting for fenced code blocks).
    """

    def __init__(self, server_url: str = "http://localhost:8000"):
        self.server_url = server_url.rstrip('/')
        # BUG FIX: build the API URL from the *normalized* self.server_url,
        # not the raw argument — a trailing slash previously produced
        # ".../api/chat" with a double slash.
        self.api_url = f"{self.server_url}/api/chat"
        self.history: List[Dict[str, str]] = []
        self.current_language = "python"
        self.temperature = 0.7

    def check_server_connection(self) -> bool:
        """Check if the model server is running (GET /health, 5 s timeout)."""
        try:
            response = requests.get(f"{self.server_url}/health", timeout=5)
            return response.status_code == 200
        except requests.exceptions.RequestException:
            return False

    def send_message(self, message: str) -> Dict[str, Any]:
        """Send a message to the model server and get response.

        Always returns a dict in the server's response shape; transport or
        server errors are reported as the assistant's message content so
        the caller can render them uniformly.
        """
        try:
            payload = {
                "message": message,
                "history": self.history,
                "language": self.current_language,
                "temperature": self.temperature
            }

            response = requests.post(
                self.api_url,
                json=payload,
                headers={"Content-Type": "application/json"},
                timeout=60
            )

            if response.status_code == 200:
                return response.json()
            else:
                return {
                    "choices": [{"message": {"content": f"Server error: {response.status_code}"}}],
                    "history": self.history
                }

        except requests.exceptions.RequestException as e:
            return {
                "choices": [{"message": {"content": f"Connection error: {str(e)}"}}],
                "history": self.history
            }

    def format_response(self, response: str) -> None:
        """Format and display the model's response.

        Splits on ``` fences: even-indexed parts are prose (markdown),
        odd-indexed parts are code (syntax-highlighted).
        """
        if not response:
            return

        # Split response into code blocks and text
        parts = response.split('```')

        for i, part in enumerate(parts):
            if i % 2 == 0:  # Text parts
                if part.strip():
                    # BUG FIX: bare `except:` also swallowed
                    # KeyboardInterrupt/SystemExit; narrowed to Exception.
                    try:
                        markdown = Markdown(part.strip())
                        console.print(markdown)
                    except Exception:
                        console.print(part.strip())
            else:  # Code parts
                # First line of a fence is the (optional) language tag.
                lines = part.split('\n', 1)
                if len(lines) >= 2:
                    language = lines[0].strip() if lines[0].strip() else 'text'
                    code = lines[1]
                else:
                    language = 'text'
                    code = part

                if code.strip():
                    syntax = Syntax(code.strip(), language, theme="monokai", line_numbers=True)
                    console.print(syntax)

    def show_welcome(self) -> None:
        """Display welcome message and help."""
        welcome_text = """
# 🤖 AI Coding Assistant - Terminal Version

Welcome to your AI-powered coding companion! I can help you with:

• **Code Generation** - Write functions, classes, and complete programs
• **Debugging** - Find and fix errors in your code
• **Algorithm Implementation** - From simple to complex algorithms
• **Best Practices** - Clean, efficient, and readable code
• **Concept Explanation** - Understand programming concepts

## Quick Start Commands:
• `/help` - Show this help
• `/lang <language>` - Change programming language
• `/temp <value>` - Set creativity (0.1-1.0)
• `/clear` - Clear chat history
• `/quit` or `/exit` - Exit the program

## Example Prompts:
• "Write a Python function to reverse a linked list"
• "Create a React component for user authentication"
• "Explain Big O notation with code examples"
• "Debug this JavaScript code: [paste your code]"

**Ready to code? Just ask me anything!**
"""

        panel = Panel(
            Markdown(welcome_text),
            title="AI Coder Terminal",
            border_style="blue",
            padding=(1, 2)
        )
        console.print(panel)

    def show_settings(self) -> None:
        """Display current settings."""
        table = Table(title="Current Settings")
        table.add_column("Setting", style="cyan")
        table.add_column("Value", style="green")
        table.add_column("Description", style="yellow")

        table.add_row("Language", self.current_language, "Target programming language")
        table.add_row("Temperature", str(self.temperature), "Creativity level (0.1-1.0)")
        table.add_row("Server", self.server_url, "Model server URL")
        table.add_row("History", str(len(self.history)), "Messages in conversation")

        console.print(table)

    def handle_command(self, command: str) -> bool:
        """Handle special commands. Returns True if command was processed."""
        cmd = command.lower().strip()

        if cmd in ['/help', '/h']:
            self.show_help()
            return True
        elif cmd.startswith('/lang '):
            # Split the *original* command so language casing is preserved.
            language = command.split(' ', 1)[1].strip()
            self.current_language = language
            console.print(f"[green]✓[/green] Language set to: {language}")
            return True
        elif cmd.startswith('/temp '):
            try:
                temp = float(command.split(' ', 1)[1].strip())
                if 0.1 <= temp <= 1.0:
                    self.temperature = temp
                    console.print(f"[green]✓[/green] Temperature set to: {temp}")
                else:
                    console.print("[red]Temperature must be between 0.1 and 1.0[/red]")
            except ValueError:
                console.print("[red]Invalid temperature value[/red]")
            return True
        elif cmd in ['/settings', '/config']:
            self.show_settings()
            return True
        elif cmd in ['/clear', '/reset']:
            self.history = []
            console.print("[green]✓[/green] Chat history cleared")
            return True
        elif cmd in ['/quit', '/exit', '/q']:
            console.print("[yellow]Goodbye! 👋[/yellow]")
            sys.exit(0)
        else:
            console.print(f"[red]Unknown command: {command}[/red]")
            return True

    def show_help(self) -> None:
        """Display detailed help information."""
        help_text = """
# Available Commands

## Chat Commands
- **Regular text**: Ask questions or request code
- **/help** or **/h**: Show this help message
- **/settings**: Display current settings
- **/clear**: Clear chat history

## Configuration Commands
- **/lang <language>**: Change programming language
  - Example: `/lang javascript`
- **/temp <value>**: Set creativity level (0.1-1.0)
  - Example: `/temp 0.3` (more precise)
  - Example: `/temp 0.9` (more creative)

## Exit Commands
- **/quit** or **/exit**: Exit the program

## Programming Languages Supported
python, javascript, java, cpp, c, go, rust, typescript,
php, ruby, swift, kotlin, sql, html, css, bash, powershell
"""

        console.print(Panel(Markdown(help_text), title="Help", border_style="green"))

    def run(self) -> None:
        """Main chatbot loop."""
        self.show_welcome()

        # Check server connection before entering the loop.
        console.print("\n[yellow]Checking server connection...[/yellow]")
        if not self.check_server_connection():
            console.print(f"[red]❌ Cannot connect to server at {self.server_url}[/red]")
            console.print("[yellow]💡 Make sure the model server is running with:[/yellow]")
            console.print("[cyan]python model_server.py[/cyan]")
            return

        console.print("[green]✓[/green] Connected to model server!")

        # Main interaction loop
        while True:
            try:
                # Get user input
                user_input = Prompt.ask(
                    f"[bold blue]You[/bold blue] ({self.current_language})"
                ).strip()

                if not user_input:
                    continue

                # Handle commands
                if user_input.startswith('/'):
                    self.handle_command(user_input)
                    continue

                # Show spinner while the (blocking) API call runs.
                with console.status("[bold green]AI is thinking...[/bold green]"):
                    start_time = time.time()
                    response_data = self.send_message(user_input)
                    end_time = time.time()
                    response_time = end_time - start_time

                # Display response
                if response_data and "choices" in response_data:
                    response = response_data["choices"][0]["message"]["content"]

                    console.print(f"\n[dim]Response time: {response_time:.2f}s[/dim]")
                    console.print(f"[bold green]AI:[/bold green]")

                    self.format_response(response)

                    # Adopt the server's authoritative history (it appends
                    # both new turns); keep ours on error responses.
                    self.history = response_data.get("history", self.history)

                    console.print()  # Add spacing
                else:
                    console.print("[red]❌ Invalid response from server[/red]")

            except KeyboardInterrupt:
                console.print("\n[yellow]Interrupted by user[/yellow]")
                if Confirm.ask("Exit the program?"):
                    break
            except Exception as e:
                console.print(f"[red]❌ Error: {str(e)}[/red]")
|
| 281 |
+
|
| 282 |
+
def main():
    """CLI entry point: pick the server URL and run the terminal chat loop."""
    # The first positional argument, when present, overrides the default
    # local model-server address.
    server_url = sys.argv[1] if len(sys.argv) > 1 else "http://localhost:8000"

    console.print("[cyan]AI Coding Assistant Terminal[/cyan]")
    console.print(f"[dim]Server: {server_url}[/dim]\n")

    TerminalChatbot(server_url).run()


if __name__ == "__main__":
    main()
|
updated_app.py
ADDED
|
@@ -0,0 +1,271 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import gradio as gr
|
| 2 |
+
import requests
|
| 3 |
+
import json
|
| 4 |
+
from typing import List, Dict, Any, Optional
|
| 5 |
+
import os
|
| 6 |
+
import time
|
| 7 |
+
|
| 8 |
+
# Configuration
|
| 9 |
+
API_BASE_URL = os.getenv("API_BASE_URL", "http://localhost:8000")
|
| 10 |
+
API_HEALTH_URL = f"{API_BASE_URL}/health"
|
| 11 |
+
API_CHAT_URL = f"{API_BASE_URL}/api/chat"
|
| 12 |
+
API_INFO_URL = f"{API_BASE_URL}/model/info"
|
| 13 |
+
|
| 14 |
+
def check_api_connection() -> Dict[str, Any]:
    """Probe the model server's health endpoint.

    Returns the server's health JSON on HTTP 200; otherwise a dict with
    ``status: "error"`` and a human-readable ``message``. Never raises.
    """
    try:
        resp = requests.get(API_HEALTH_URL, timeout=5)
        if resp.status_code == 200:
            return resp.json()
        return {"status": "error", "message": f"API returned status {resp.status_code}"}
    except requests.exceptions.RequestException as exc:
        # Connection refused, DNS failure, timeout, bad JSON, etc.
        return {"status": "error", "message": f"Connection failed: {str(exc)}"}
|
| 24 |
+
|
| 25 |
+
def chat_with_api(message: str, history: List[Dict[str, str]], language: str = "python", temperature: float = 0.7) -> Dict[str, Any]:
    """Send one chat turn to the model API and return its JSON response.

    On any failure (server unhealthy, HTTP error, network error), returns an
    OpenAI-style payload whose single choice carries the error text, with the
    caller's history passed back unchanged. Never raises.
    """

    def _error_reply(text: str) -> Dict[str, Any]:
        # Shape errors like a normal completion so the UI renders them inline.
        return {"choices": [{"message": {"content": text}}], "history": history}

    try:
        # Fail fast with a helpful hint when the model server is down.
        health_status = check_api_connection()
        if health_status.get("status") != "healthy":
            return _error_reply(
                f"❌ API Server Error: {health_status.get('message', 'Unknown error')}\n\n💡 Make sure the model server is running:\n```bash\npython model_server.py\n```"
            )

        resp = requests.post(
            API_CHAT_URL,
            json={
                "message": message,
                "history": history,
                "language": language,
                "temperature": temperature,
            },
            headers={"Content-Type": "application/json"},
            timeout=60,
        )

        if resp.status_code == 200:
            return resp.json()
        return _error_reply(f"API Error: {resp.status_code} - {resp.text}")

    except requests.exceptions.RequestException as e:
        return _error_reply(f"Connection error: {str(e)}")
    except Exception as e:
        return _error_reply(f"Error: {str(e)}")
|
| 68 |
+
|
| 69 |
+
def get_model_info_api() -> Dict[str, Any]:
    """Fetch model metadata from the API server.

    Returns the server's info JSON, or ``{"error": ...}`` on any failure.
    """
    try:
        resp = requests.get(API_INFO_URL, timeout=5)
        if resp.status_code != 200:
            return {"error": f"Failed to get model info: {resp.status_code}"}
        return resp.json()
    except Exception as e:
        return {"error": f"Failed to get model info: {str(e)}"}
|
| 79 |
+
|
| 80 |
+
def create_demo():
    """Build and return the Gradio Blocks UI for the API-backed chatbot.

    The UI talks to the model server only through the module-level helpers
    (check_api_connection / chat_with_api / get_model_info_api), so it can be
    built even while the server is still loading the model.
    """

    with gr.Blocks(
        # FIX: the original also passed description=..., but gr.Blocks does not
        # accept a `description` kwarg (that parameter belongs to gr.Interface)
        # and it raised TypeError at startup. The page blurb lives in the
        # gr.HTML header below instead.
        title="AI Coder - 5B Parameter Chatbot (API)",
        theme=gr.themes.Soft(),
        css="""
        .container {max-width: 1200px !important;}
        .header {text-align: center; padding: 20px;}
        .header h1 {color: #2d3748; margin-bottom: 10px;}
        .header a {color: #3182ce; text-decoration: none; font-weight: bold;}
        .header a:hover {text-decoration: underline;}
        .status-indicator {padding: 10px; border-radius: 5px; margin: 10px 0;}
        .status-online {background-color: #d4edda; color: #155724;}
        .status-offline {background-color: #f8d7da; color: #721c24;}
        .coding-section {background: #f7fafc; border-radius: 8px; padding: 15px; margin: 10px 0;}
        """
    ) as demo:

        # Header
        gr.HTML("""
        <div class="header">
            <h1>🤖 AI Coder - API Client</h1>
            <p>AI chatbot with coding features powered by a 5B parameter model via API</p>
            <p>Built with <a href="https://huggingface.co/spaces/akhaliq/anycoder" target="_blank">anycoder</a></p>
        </div>
        """)

        # Online/offline banner, filled once at build time (see bottom).
        status_display = gr.HTML()

        def update_status():
            """Return an HTML banner reflecting the server's current health."""
            status = check_api_connection()
            if status.get("status") == "healthy":
                return f"""
                <div class="status-indicator status-online">
                    ✅ API Server: Online - Model: {status.get('model_name', 'Unknown')}
                </div>
                """
            else:
                return f"""
                <div class="status-indicator status-offline">
                    ❌ API Server: Offline - {status.get('message', 'Unknown error')}
                </div>
                """

        # Main chat interface
        with gr.Row():
            # Left column - Chat
            with gr.Column(scale=3):
                chatbot = gr.Chatbot(
                    label="AI Coding Assistant",
                    height=600,
                    type="messages",
                    # NOTE(review): recent Gradio versions expect a filepath/URL
                    # for avatar images — confirm the emoji string is accepted.
                    avatar_images=(None, "🤖"),
                    show_copy_button=True
                )

                with gr.Row():
                    msg = gr.Textbox(
                        placeholder="Ask me to code something, debug code, or explain programming concepts...",
                        lines=3,
                        scale=4
                    )
                    send_btn = gr.Button("Send", variant="primary", scale=1)

                with gr.Row():
                    clear_btn = gr.Button("Clear Chat", variant="secondary")

            # Right column - Controls
            with gr.Column(scale=1):
                gr.Markdown("### 🛠️ Settings")

                language = gr.Dropdown(
                    choices=[
                        "python", "javascript", "java", "cpp", "c", "go",
                        "rust", "typescript", "php", "ruby", "swift", "kotlin",
                        "sql", "html", "css", "bash", "powershell"
                    ],
                    value="python",
                    label="Programming Language"
                )

                temperature = gr.Slider(
                    minimum=0.1,
                    maximum=1.0,
                    value=0.7,
                    step=0.1,
                    label="Creativity (Temperature)"
                )

                # API Status info
                with gr.Accordion("🔗 API Status", open=True):
                    status_text = gr.Markdown()

                with gr.Accordion("🎯 Quick Prompts", open=False):
                    gr.Examples(
                        examples=[
                            "Write a Python function to reverse a linked list",
                            "Create a React component for a login form",
                            "Debug this JavaScript code: [paste code]",
                            "Explain Big O notation with examples",
                            "Create a binary search algorithm in C++"
                        ],
                        inputs=msg,
                        examples_per_page=3
                    )

                with gr.Accordion("ℹ️ API Info", open=False):
                    api_info = gr.Markdown()

        def get_api_info():
            """Return a Markdown summary of the served model, or an error line."""
            info = get_model_info_api()
            if "error" not in info:
                return f"""
                **Model:** {info.get('model_name', 'Unknown')}
                **Parameters:** {info.get('parameter_count', 'Unknown')}
                **Max Length:** {info.get('max_length', 'Unknown'):,} tokens
                **Device:** {info.get('device', 'Unknown')}
                **Status:** {'✅ Loaded' if info.get('is_loaded') else '⏳ Loading...'}
                **Vocab Size:** {info.get('vocab_size', 'Unknown'):,}
                """
            else:
                return f"❌ {info['error']}"

        # Queried once at build time; requires the API server to be reachable.
        api_info.value = get_api_info()

        # Event handlers
        def user(user_message, history):
            # Append the user's turn and clear the textbox.
            return "", history + [{"role": "user", "content": user_message}]

        def bot(history, selected_language, temp):
            # Send the newest user turn to the API with the prior turns as context.
            if not history:
                return history
            last_message = history[-1]["content"]
            result = chat_with_api(last_message, history[:-1], selected_language, temp)
            return result["history"]

        # Wire up events: Enter key and Send button follow the same pipeline.
        msg.submit(
            user,
            [msg, chatbot],
            [msg, chatbot],
            queue=False
        ).then(
            bot,
            [chatbot, language, temperature],
            chatbot
        )

        send_btn.click(
            user,
            [msg, chatbot],
            [msg, chatbot],
            queue=False
        ).then(
            bot,
            [chatbot, language, temperature],
            chatbot
        )

        clear_btn.click(
            lambda: [{"role": "assistant", "content": "Hello! I'm your AI coding assistant. I can help you with Python, JavaScript, Java, C++, and many other programming languages. What would you like to code today?"}],
            outputs=[chatbot]
        )

        # Helper for refreshing both status widgets (e.g. from a gr.Timer);
        # currently not wired to any event.
        def update_all_status():
            status_html = update_status()
            api_info_text = get_api_info()
            return status_html, api_info_text

        # Initial status update
        status_display.value = update_status()

        # Load initial greeting message
        chatbot.value = [{"role": "assistant", "content": "Hello! I'm your AI coding assistant powered by a 5B parameter language model via API. I can help you with Python, JavaScript, Java, C++, and many other programming languages. What would you like to code today?"}]

    return demo
|
| 261 |
+
|
| 262 |
+
if __name__ == "__main__":
    # Build the UI and serve it on all interfaces (HF Spaces expects port 7860).
    demo = create_demo()
    launch_options = dict(
        server_name="0.0.0.0",
        server_port=7860,
        show_error=True,
        share=False,
        debug=True,
        # NOTE(review): mcp_server=True requires the `gradio[mcp]` extra —
        # confirm it is installed in this environment.
        mcp_server=True,
    )
    demo.launch(**launch_options)
|
updated_models.py
ADDED
|
@@ -0,0 +1,213 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
|
| 3 |
+
from typing import List, Dict, Any, Optional
|
| 4 |
+
import logging
|
| 5 |
+
import asyncio
|
| 6 |
+
import threading
|
| 7 |
+
|
| 8 |
+
logger = logging.getLogger(__name__)
|
| 9 |
+
|
| 10 |
+
class CodeModel:
    """Wrapper around a large causal-LM coding model with thread-safe inference.

    Attempts to load bigcode/starcoder2-7b; on failure falls back to
    microsoft/DialoGPT-medium. Generation is serialized with a lock and never
    raises to callers — errors come back as apologetic strings.
    """

    def __init__(self):
        # NOTE: despite the "5B" naming elsewhere in the project, this is a
        # 7B checkpoint.
        self.model_name = "bigcode/starcoder2-7b"  # 7B model (closest to 5B with excellent coding)
        self.parameter_count = "7B"
        self.max_length = 16384        # context size advertised via get_model_info()
        self.tokenizer = None
        self.model = None
        self.pipeline = None
        self.is_loaded = False         # flipped True only after a successful load
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self._lock = threading.Lock()  # serializes generate() across threads

    # FIX: the original decorated this method with @spaces.GPU(duration=1200),
    # but the `spaces` module is never imported in this file, so evaluating the
    # decorator raised NameError at import time (the class could not even be
    # defined). If this runs on HF Spaces ZeroGPU, add `import spaces` at the
    # top of the file and restore the decorator.
    def load_model(self):
        """Load tokenizer, model, and generation pipeline; fall back on error."""
        try:
            logger.info(f"Loading {self.model_name} model...")

            self.tokenizer = AutoTokenizer.from_pretrained(
                self.model_name,
                trust_remote_code=True,
                padding_side="left"
            )
            # Some checkpoints ship without a pad token; reuse EOS so batched
            # generation does not crash.
            if self.tokenizer.pad_token is None:
                self.tokenizer.pad_token = self.tokenizer.eos_token

            self.model = AutoModelForCausalLM.from_pretrained(
                self.model_name,
                torch_dtype=torch.float16 if self.device == "cuda" else torch.float32,
                device_map="auto" if self.device == "cuda" else None,
                trust_remote_code=True,
                low_cpu_mem_usage=True,
                use_cache=True
            )
            self.model.eval()

            # Pipeline wraps tokenize → generate → decode for convenience.
            self.pipeline = pipeline(
                "text-generation",
                model=self.model,
                tokenizer=self.tokenizer,
                device=0 if self.device == "cuda" else -1,
                do_sample=True,
                temperature=0.7,
                top_p=0.95,
                repetition_penalty=1.1,
                max_new_tokens=2048,
                pad_token_id=self.tokenizer.eos_token_id
            )

            self.is_loaded = True
            logger.info(f"✅ {self.model_name} loaded successfully on {self.device}")

        except Exception as e:
            logger.error(f"❌ Error loading model: {e}")
            self._fallback_model()

    def _fallback_model(self):
        """Fallback to a smaller model if the main model fails to load."""
        try:
            logger.info("Trying fallback model: microsoft/DialoGPT-medium")
            # Overwrite the advertised metadata so get_model_info() stays honest.
            self.model_name = "microsoft/DialoGPT-medium"
            self.parameter_count = "345M"
            self.max_length = 1024

            self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)
            if self.tokenizer.pad_token is None:
                self.tokenizer.pad_token = self.tokenizer.eos_token

            self.model = AutoModelForCausalLM.from_pretrained(
                self.model_name,
                torch_dtype=torch.float16 if self.device == "cuda" else torch.float32,
                device_map="auto" if self.device == "cuda" else None
            )

            self.pipeline = pipeline(
                "text-generation",
                model=self.model,
                tokenizer=self.tokenizer,
                device=0 if self.device == "cuda" else -1,
                max_new_tokens=512,
                pad_token_id=self.tokenizer.eos_token_id
            )

            self.is_loaded = True
            logger.info(f"✅ Fallback model loaded successfully")

        except Exception as e:
            logger.error(f"❌ Fallback model also failed: {e}")
            self.is_loaded = False

    def generate(
        self,
        messages: List[Dict[str, str]],
        temperature: float = 0.7,
        max_new_tokens: int = 2048,
        language: str = "python"
    ) -> str:
        """Generate an assistant reply for chat-format ``messages``.

        Args:
            messages: list of {"role": ..., "content": ...} dicts.
            temperature: sampling temperature.
            max_new_tokens: generation budget.
            language: hint appended to coding prompts.

        Returns:
            The generated text, or a human-readable error string if the model
            is not loaded or generation fails (callers never see an exception).
        """

        if not self.is_loaded:
            return "I'm sorry, the model is not loaded yet. Please try again in a moment."

        try:
            with self._lock:  # one generation at a time; model is not re-entrant
                # FIX: `conversation` was previously only assigned inside
                # `if messages:`, so an empty message list raised NameError
                # when the prompt was used below.
                conversation = ""
                if messages:
                    # Flatten the chat into a plain-text transcript.
                    for msg in messages:
                        role = msg["role"]
                        content = msg["content"]
                        if role == "system":
                            conversation += f"System: {content}\n\n"
                        elif role == "user":
                            conversation += f"Human: {content}\n"
                        elif role == "assistant":
                            conversation += f"Assistant: {content}\n"

                    # Nudge the model toward quality output on coding requests.
                    if "write" in conversation.lower() or "code" in conversation.lower():
                        conversation += f"\n\nPlease provide clean, well-commented {language} code with proper syntax and best practices."

                    conversation += "\nAssistant:"

                with torch.no_grad():
                    if self.pipeline:
                        outputs = self.pipeline(
                            conversation,
                            do_sample=True,
                            temperature=temperature,
                            top_p=0.95,
                            repetition_penalty=1.1,
                            max_new_tokens=max_new_tokens,
                            pad_token_id=self.tokenizer.eos_token_id,
                            eos_token_id=self.tokenizer.eos_token_id,
                            return_full_text=False,  # keep only newly generated text
                            clean_up_tokenization_spaces=True
                        )
                        if outputs and len(outputs) > 0:
                            return outputs[0]["generated_text"].strip()

                # Fallback: drive model.generate directly when no pipeline exists.
                inputs = self.tokenizer.encode(conversation, return_tensors="pt").to(self.device)

                with torch.no_grad():
                    outputs = self.model.generate(
                        inputs,
                        do_sample=True,
                        temperature=temperature,
                        top_p=0.95,
                        repetition_penalty=1.1,
                        max_new_tokens=max_new_tokens,
                        pad_token_id=self.tokenizer.eos_token_id,
                        eos_token_id=self.tokenizer.eos_token_id,
                        attention_mask=torch.ones_like(inputs)
                    )

                # Decode only the continuation, skipping the prompt tokens.
                response = self.tokenizer.decode(outputs[0][inputs.shape[1]:], skip_special_tokens=True)
                return response.strip()

        except Exception as e:
            logger.error(f"Generation error: {e}")
            return f"I apologize, but I encountered an error while generating the response: {str(e)}"

    def get_model_info(self) -> Dict[str, Any]:
        """Get information about the loaded model."""
        return {
            "model_name": self.model_name,
            "parameter_count": self.parameter_count,
            "max_length": self.max_length,
            "device": self.device,
            "is_loaded": self.is_loaded,
            "vocab_size": len(self.tokenizer) if self.tokenizer else 0
        }
|
| 198 |
+
|
| 199 |
+
# Lazily-created singleton shared by the server process.
_global_model = None

def get_model():
    """Return the process-wide CodeModel, creating and loading it on first use."""
    global _global_model
    if _global_model is not None:
        return _global_model
    _global_model = CodeModel()
    # Kick off the slow weight load in the background so callers get an
    # immediately usable handle; is_loaded flips once loading completes.
    threading.Thread(target=_global_model.load_model, daemon=True).start()
    return _global_model
|
| 210 |
+
|
| 211 |
+
# FIX: a `def CodeModel()` "factory" previously defined here shadowed the
# CodeModel class above and its body was `return CodeModel()` — i.e. it
# called itself, recursing infinitely on first use and breaking get_model().
# It has been removed: callers that invoke `CodeModel()` now construct the
# class directly, which is what the factory intended.
|