import asyncio
from threading import Thread
from typing import List, Optional

import torch
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse
from pydantic import BaseModel
from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer

# ── APP SETUP ─────────────────────────────────────────

app = FastAPI(title="DevOS AI", description="AI coding agent by Cool Shot System")

# NOTE(review): FastAPI's CORS docs state that allow_origins / allow_methods /
# allow_headers must be listed explicitly (not ["*"]) when allow_credentials
# is True. As configured, any site can make credentialed cross-origin calls.
# Tighten this to the real front-end origin(s) before exposing the service.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# ── MODEL LOADING ─────────────────────────────────────

MODEL_NAME = "deepseek-ai/deepseek-coder-1.3b-instruct"

# Loaded once at import time; every request shares this tokenizer and model.
print(f"Loading model: {MODEL_NAME} ...")
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    torch_dtype=torch.float32,  # CPU-safe
    low_cpu_mem_usage=True,
)
model.eval()  # inference mode: disables dropout
print("Model loaded ✓")


# ── SCHEMAS ───────────────────────────────────────────

class CodeRequest(BaseModel):
    """Request body for /complete: a code snippet to be continued."""

    code: str  # source text to continue
    language: str = "python"  # interpolated into the completion prompt
    max_tokens: int = 128  # passed to generate() as max_new_tokens


class ChatMessage(BaseModel):
    """One turn of a conversation."""

    role: str  # "user" or "assistant"; anything but "user" is labeled as the AI
    content: str


class ChatRequest(BaseModel):
    """Request body for /chat and /stream."""

    messages: List[ChatMessage]
    system: Optional[str] = ""  # optional system preamble placed before the turns
    max_tokens: int = 1024  # passed to generate() as max_new_tokens


# ── HELPERS ───────────────────────────────────────────

def build_prompt(messages: List[ChatMessage], system: Optional[str] = "") -> str:
    """Render a conversation as a plain-text prompt ending with the AI cue.

    Layout: an optional stripped ``system`` paragraph followed by a blank line,
    then one ``"User: ..."`` / ``"DevOS AI: ..."`` line per message (only the
    last 10 messages, to bound prompt length), then a trailing ``"DevOS AI:"``
    so the model continues as the assistant.

    Args:
        messages: Conversation turns, oldest first.
        system: Optional system text; ``None``, empty, or whitespace-only
            values are omitted.

    Returns:
        The prompt string.
    """
    parts: List[str] = []
    if system and system.strip():
        parts.append(system.strip() + "\n\n")
    for msg in messages[-10:]:  # last 10 messages for context window
        role_label = "User" if msg.role == "user" else "DevOS AI"
        parts.append(f"{role_label}: {msg.content.strip()}\n")
    parts.append("DevOS AI:")
    return "".join(parts)


# ── ROUTES ────────────────────────────────────────────

@app.get("/")
def root():
    """Report service status, the loaded model, and the generation endpoints."""
    return {
        "status": "DevOS AI is running",
        "model": MODEL_NAME,
        "endpoints": ["/complete", "/chat", "/stream"],
    }


@app.get("/health")
def health():
    """Liveness probe."""
    return {"status": "ok"}
# ── /complete — inline code completion ────────────────

# Maximum number of prompt tokens fed to the model. Longer prompts keep only
# their most recent tokens (see _encode_prompt).
MAX_PROMPT_TOKENS = 2048

# Sampling settings shared by /chat and /stream so the two stay in sync.
_CHAT_SAMPLING = {"temperature": 0.4, "repetition_penalty": 1.1}


def _encode_prompt(prompt: str) -> dict:
    """Tokenize ``prompt`` for ``model.generate``, keeping its LAST tokens.

    Tokenizer-side truncation cuts from the right by default. That would drop
    the end of an over-long prompt: the code right before the cursor, or the
    trailing "DevOS AI:" cue. Here the oldest tokens are dropped instead.
    """
    inputs = tokenizer(prompt, return_tensors="pt")
    if inputs["input_ids"].shape[1] > MAX_PROMPT_TOKENS:
        # Slice every tensor (input_ids, attention_mask, ...) identically so
        # they stay aligned.
        inputs = {key: value[:, -MAX_PROMPT_TOKENS:] for key, value in inputs.items()}
    return inputs


def _generate_text(prompt: str, max_new_tokens: int, **sampling) -> str:
    """Run one non-streaming sampled generation and return only the new text.

    Args:
        prompt: Full prompt text.
        max_new_tokens: Upper bound on generated tokens.
        **sampling: Extra ``generate`` kwargs (e.g. temperature).

    Returns:
        The generated continuation, decoded without special tokens and stripped.
    """
    inputs = _encode_prompt(prompt)
    prompt_length = inputs["input_ids"].shape[1]
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            do_sample=True,
            pad_token_id=tokenizer.eos_token_id,
            eos_token_id=tokenizer.eos_token_id,
            **sampling,
        )
    # Decode only the tokens produced after the prompt. Slicing the decoded
    # string by len(prompt) is unreliable: decoding need not reproduce the
    # prompt text exactly, and a truncated prompt is shorter than len(prompt).
    new_tokens = outputs[0][prompt_length:]
    return tokenizer.decode(new_tokens, skip_special_tokens=True).strip()


def _sse_event(data: str) -> str:
    """Format ``data`` as one Server-Sent Event.

    Each line of ``data`` gets its own ``data:`` field, and SSE clients rejoin
    those fields with newlines. A raw newline inside a single ``data:`` field
    would end the field early, and the rest of the text would be lost.
    """
    lines = data.replace("\r\n", "\n").replace("\r", "\n").split("\n")
    return "".join(f"data: {line}\n" for line in lines) + "\n"


@app.post("/complete")
def complete_code(request: CodeRequest):
    """Suggest a continuation for ``request.code``.

    Returns:
        ``{"suggestion": str}``, the generated continuation (stripped).
    """
    prompt = f"Continue the following {request.language} code:\n{request.code}"
    suggestion = _generate_text(prompt, request.max_tokens, temperature=0.2)
    return {"suggestion": suggestion}


# ── /chat — full conversation, single response ─────────

@app.post("/chat")
def chat(request: ChatRequest):
    """Generate one assistant reply for the conversation in ``request``.

    Returns:
        ``{"reply": str}``, the generated reply (stripped).
    """
    prompt = build_prompt(request.messages, request.system)
    reply = _generate_text(prompt, request.max_tokens, **_CHAT_SAMPLING)
    return {"reply": reply}


# ── /stream — streaming response (SSE) ────────────────

@app.post("/stream")
async def stream_chat(request: ChatRequest):
    """Stream an assistant reply as Server-Sent Events.

    Each non-empty text chunk from the streamer is sent as one SSE event.
    The stream ends with a ``[DONE]`` event.
    """
    prompt = build_prompt(request.messages, request.system)
    inputs = _encode_prompt(prompt)
    streamer = TextIteratorStreamer(
        tokenizer, skip_prompt=True, skip_special_tokens=True
    )
    generation_kwargs = dict(
        **inputs,
        max_new_tokens=request.max_tokens,
        do_sample=True,
        pad_token_id=tokenizer.eos_token_id,
        eos_token_id=tokenizer.eos_token_id,
        streamer=streamer,
        **_CHAT_SAMPLING,
    )

    def run_generation():
        try:
            model.generate(**generation_kwargs)
        except Exception:
            # Send the streamer's end-of-stream signal so the consumer loop
            # below terminates instead of waiting forever. Re-raise so the
            # thread still reports the error.
            streamer.end()
            raise

    # Run generation in a background thread so tokens can be streamed as they
    # arrive.
    # NOTE(review): if the client disconnects, generation keeps running until
    # EOS or max_new_tokens. Consider a StoppingCriteria tied to disconnect.
    Thread(target=run_generation).start()

    def token_generator():
        # Plain (sync) generator on purpose. StreamingResponse iterates sync
        # iterators in a worker thread, so blocking on the streamer's queue
        # does not stall the event loop.
        # NOTE(review): on a generation error the client still gets [DONE].
        for token in streamer:
            if token:
                yield _sse_event(token)
        yield _sse_event("[DONE]")

    return StreamingResponse(
        token_generator(),
        media_type="text/event-stream",
        headers={
            "Cache-Control": "no-cache",
            "X-Accel-Buffering": "no",
            "Connection": "keep-alive",
        },
    )