"""Minimal FastAPI chat server for a few Hugging Face code-generation models.

Routes:
    GET  /           -> serves ``index.html`` (path relative to the working directory)
    GET  /style.css  -> serves ``style.css`` (path relative to the working directory)
    POST /chat       -> runs one chat turn against the selected model
"""

import json
import threading
from typing import Dict, List, Tuple

import torch
from fastapi import FastAPI, Form, HTTPException
from fastapi.responses import FileResponse, HTMLResponse, JSONResponse
from transformers import AutoModelForCausalLM, AutoTokenizer

# ---------------------------
# Models
# ---------------------------
# Display name (sent by the client as ``model_choice``) -> Hugging Face Hub repo id.
MODEL_OPTIONS = {
    "DeepSeek Coder 1.3B": "deepseek-ai/deepseek-coder-1.3b-instruct",
    "StarCoder 1B": "bigcode/starcoderbase-1b",
    "CodeLLaMA 7B": "codellama/CodeLlama-7b-Instruct-hf",
}

# Generation settings applied to every /chat request.
MAX_NEW_TOKENS = 512
TEMPERATURE = 0.7
TOP_P = 0.9

# Cache of loaded models: display name -> (tokenizer, model).
loaded_models = {}
# Serializes model loading so concurrent requests don't load the same model twice.
_load_lock = threading.Lock()
# Serializes generate() calls. /chat runs in FastAPI's thread pool, and this lock
# keeps generation one-at-a-time, as it effectively was when the handler blocked
# the event loop.
_generate_lock = threading.Lock()


def get_model(model_key):
    """Return the cached ``(tokenizer, model)`` for *model_key*, loading it on first use.

    Raises:
        KeyError: if *model_key* is not a key of ``MODEL_OPTIONS``.
    """
    cached = loaded_models.get(model_key)
    if cached is not None:
        return cached
    with _load_lock:
        # Re-check: another thread may have finished loading while we waited.
        if model_key not in loaded_models:
            model_name = MODEL_OPTIONS[model_key]
            tokenizer = AutoTokenizer.from_pretrained(model_name)
            model = AutoModelForCausalLM.from_pretrained(
                model_name,
                # Half precision on GPU; full precision otherwise.
                torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
                device_map="auto",
            )
            loaded_models[model_key] = (tokenizer, model)
        return loaded_models[model_key]


def _parse_history(raw: str) -> List[Tuple[str, str]]:
    """Decode the client-supplied history JSON into ``(role, content)`` pairs.

    Raises:
        ValueError: if *raw* is not a JSON array of ``[role, content]`` string pairs.
    """
    try:
        data = json.loads(raw)
    except json.JSONDecodeError as err:
        raise ValueError(f"history is not valid JSON: {err}") from err
    if not isinstance(data, list):
        raise ValueError("history must be a JSON array")
    turns = []
    for entry in data:
        if not (
            isinstance(entry, list)
            and len(entry) == 2
            and all(isinstance(part, str) for part in entry)
        ):
            raise ValueError("each history entry must be a [role, content] pair of strings")
        turns.append((entry[0], entry[1]))
    return turns


def _build_messages(history: List[Tuple[str, str]], user_input: str) -> List[Dict[str, str]]:
    """Convert prior turns plus the new user message into chat-template messages."""
    # Any role other than "user" is mapped to "assistant".
    messages = [
        {"role": "user" if role == "user" else "assistant", "content": content}
        for role, content in history
    ]
    messages.append({"role": "user", "content": user_input})
    return messages


def _encode_prompt(tokenizer, messages: List[Dict[str, str]]):
    """Tokenize *messages* into model inputs (a mapping that includes ``input_ids``).

    Uses the tokenizer's chat template when one is set. Tokenizers without a
    template (typical for base, non-instruct models) cannot use
    ``apply_chat_template``, so fall back to a plain ``Role: content``
    transcript that ends with an ``Assistant:`` cue.
    """
    if getattr(tokenizer, "chat_template", None):
        return tokenizer.apply_chat_template(
            messages,
            add_generation_prompt=True,
            tokenize=True,
            return_dict=True,
            return_tensors="pt",
        )
    transcript = "".join(f"{m['role'].capitalize()}: {m['content']}\n" for m in messages)
    return tokenizer(transcript + "Assistant:", return_tensors="pt")


# ---------------------------
# FastAPI setup
# ---------------------------
app = FastAPI()

# ---------------------------
# Routes
# ---------------------------
@app.get("/", response_class=HTMLResponse)
async def home():
    """Serve ``index.html`` (resolved against the process's working directory)."""
    return FileResponse("index.html")


@app.get("/style.css")
async def css():
    """Serve ``style.css`` (resolved against the process's working directory)."""
    return FileResponse("style.css")


@app.post("/chat")
def chat(
    user_input: str = Form(...),
    model_choice: str = Form(...),
    history: str = Form("[]"),
):
    """Run one chat turn and return the reply together with the updated history.

    Declared as a plain ``def`` so FastAPI runs it in its thread pool. Model
    loading and generation are blocking, and would otherwise stall the event
    loop (including the static-file routes) for the whole request.

    Form fields:
        user_input: the new user message.
        model_choice: a key of ``MODEL_OPTIONS``.
        history: JSON array of ``[role, content]`` pairs from earlier turns.

    Returns:
        JSON ``{"response": str, "history": [[role, content], ...]}``.

    Raises:
        HTTPException: 400 for an unknown model or malformed history.
    """
    if model_choice not in MODEL_OPTIONS:
        raise HTTPException(status_code=400, detail=f"Unknown model: {model_choice!r}")
    try:
        turns = _parse_history(history)
    except ValueError as err:
        raise HTTPException(status_code=400, detail=str(err)) from err

    tokenizer, model = get_model(model_choice)
    inputs = _encode_prompt(tokenizer, _build_messages(turns, user_input)).to(model.device)

    # Some tokenizers define no pad token; fall back to EOS so generate() has one.
    pad_token_id = (
        tokenizer.pad_token_id if tokenizer.pad_token_id is not None else tokenizer.eos_token_id
    )
    with _generate_lock:
        # do_sample=True is required; otherwise generate() decodes greedily and
        # temperature/top_p have no effect.
        outputs = model.generate(
            **inputs,
            max_new_tokens=MAX_NEW_TOKENS,
            do_sample=True,
            temperature=TEMPERATURE,
            top_p=TOP_P,
            pad_token_id=pad_token_id,
        )

    # Decode only the newly generated tokens, not the echoed prompt.
    prompt_len = inputs["input_ids"].shape[-1]
    response = tokenizer.decode(outputs[0][prompt_len:], skip_special_tokens=True)

    turns.append(("user", user_input))
    turns.append(("assistant", response))
    return JSONResponse({"response": response, "history": turns})