# Fine-Coder / app.py
# Source: Hugging Face Space by DSDUDEd (commit 2383424, 2.52 kB)
import json

import torch
from fastapi import FastAPI, Request, Form
from fastapi.responses import HTMLResponse, JSONResponse
from fastapi.staticfiles import StaticFiles
from fastapi.templating import Jinja2Templates
from transformers import AutoTokenizer, AutoModelForCausalLM
# ---------------------------
# Models
# ---------------------------
# UI display name -> Hugging Face Hub repo id for each selectable code model.
MODEL_OPTIONS = {
    "DeepSeek Coder 1.3B": "deepseek-ai/deepseek-coder-1.3b-instruct",
    "StarCoder 1B": "bigcode/starcoderbase-1b",
    "CodeLLaMA 7B": "codellama/CodeLlama-7b-Instruct-hf"
}
# Process-wide cache of already-loaded (tokenizer, model) pairs, keyed by the
# display names above; populated lazily by get_model() on first use.
loaded_models = {}
def get_model(model_key):
    """Return the (tokenizer, model) pair for *model_key*, loading it lazily.

    Pairs are memoized in the module-level ``loaded_models`` dict so each
    checkpoint is downloaded and instantiated at most once per process.

    Raises:
        KeyError: if *model_key* is not a key of ``MODEL_OPTIONS``.
    """
    cached = loaded_models.get(model_key)
    if cached is None:
        repo_id = MODEL_OPTIONS[model_key]
        # fp16 only makes sense on GPU; stay in fp32 on CPU.
        dtype = torch.float16 if torch.cuda.is_available() else torch.float32
        tok = AutoTokenizer.from_pretrained(repo_id)
        mdl = AutoModelForCausalLM.from_pretrained(
            repo_id,
            torch_dtype=dtype,
            device_map="auto",
        )
        cached = (tok, mdl)
        loaded_models[model_key] = cached
    return cached
# ---------------------------
# FastAPI setup
# ---------------------------
app = FastAPI()
# Serve static assets (CSS/JS) from ./static under the /static URL prefix.
app.mount("/static", StaticFiles(directory="static"), name="static")
# Templates are resolved from the repository root (index.html lives there).
templates = Jinja2Templates(directory=".")
# ---------------------------
# Routes
# ---------------------------
@app.get("/", response_class=HTMLResponse)
async def home(request: Request):
    """Render the chat UI, exposing the selectable model names to the template."""
    context = {"request": request, "models": list(MODEL_OPTIONS.keys())}
    return templates.TemplateResponse("index.html", context)
@app.post("/chat")
async def chat(user_input: str = Form(...), model_choice: str = Form(...), history: str = Form("[]")):
    """Generate a model reply for *user_input* and return it with updated history.

    Form fields:
        user_input:   the user's latest message.
        model_choice: display-name key into MODEL_OPTIONS.
        history:      JSON-encoded list of (role, content) pairs; roles other
                      than "user" are treated as "assistant".

    Returns:
        JSONResponse with ``response`` (the generated text) and ``history``
        (the transcript including this exchange).
    """
    past = json.loads(history)
    tokenizer, model = get_model(model_choice)
    # Rebuild the transcript in the message format apply_chat_template expects.
    messages = [
        {"role": "user" if role == "user" else "assistant", "content": content}
        for role, content in past
    ]
    messages.append({"role": "user", "content": user_input})
    inputs = tokenizer.apply_chat_template(
        messages,
        add_generation_prompt=True,
        tokenize=True,
        return_dict=True,
        return_tensors="pt",
    ).to(model.device)
    # Inference only: no_grad avoids building an autograd graph. do_sample=True
    # is required for temperature/top_p to take effect — without it generate()
    # decodes greedily and silently ignores both. pad_token_id silences the
    # per-call warning on models that define no pad token.
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=512,
            do_sample=True,
            temperature=0.7,
            top_p=0.9,
            pad_token_id=tokenizer.eos_token_id,
        )
    # Decode only the newly generated tokens, skipping the echoed prompt.
    response = tokenizer.decode(
        outputs[0][inputs["input_ids"].shape[-1]:], skip_special_tokens=True
    )
    past.append(("user", user_input))
    past.append(("assistant", response))
    return JSONResponse({"response": response, "history": past})