DSDUDEd committed on
Commit
2383424
·
verified ·
1 Parent(s): da70be6

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +73 -0
app.py ADDED
@@ -0,0 +1,73 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import json

import torch
from fastapi import FastAPI, Request, Form
from fastapi.responses import HTMLResponse, JSONResponse
from fastapi.staticfiles import StaticFiles
from fastapi.templating import Jinja2Templates
from transformers import AutoTokenizer, AutoModelForCausalLM
7
+
8
# ---------------------------
# Models
# ---------------------------
# Display name -> Hugging Face Hub repo id.
MODEL_OPTIONS = {
    "DeepSeek Coder 1.3B": "deepseek-ai/deepseek-coder-1.3b-instruct",
    "StarCoder 1B": "bigcode/starcoderbase-1b",
    "CodeLLaMA 7B": "codellama/CodeLlama-7b-Instruct-hf"
}
# Process-wide cache: model_key -> (tokenizer, model), filled lazily.
loaded_models = {}

def get_model(model_key):
    """Return the (tokenizer, model) pair for `model_key`, loading on first use.

    Loaded pairs are memoized in `loaded_models`, so each model is
    downloaded/instantiated at most once per process.
    """
    try:
        return loaded_models[model_key]
    except KeyError:
        pass  # not cached yet -- load it below

    repo_id = MODEL_OPTIONS[model_key]
    tok = AutoTokenizer.from_pretrained(repo_id)
    # fp16 on GPU, fp32 on CPU; device_map="auto" places weights automatically.
    dtype = torch.float16 if torch.cuda.is_available() else torch.float32
    lm = AutoModelForCausalLM.from_pretrained(
        repo_id,
        torch_dtype=dtype,
        device_map="auto",
    )
    loaded_models[model_key] = (tok, lm)
    return loaded_models[model_key]
29
+
30
# ---------------------------
# FastAPI setup
# ---------------------------
app = FastAPI()
# Serve static assets from ./static under the /static URL prefix.
app.mount("/static", StaticFiles(directory="static"), name="static")
# Templates are resolved relative to the current directory
# (presumably index.html lives alongside app.py -- confirm in the repo).
templates = Jinja2Templates(directory=".")
36
+
37
# ---------------------------
# Routes
# ---------------------------
@app.get("/", response_class=HTMLResponse)
async def home(request: Request):
    """Render the chat page, passing the list of selectable model names."""
    context = {"request": request, "models": list(MODEL_OPTIONS.keys())}
    return templates.TemplateResponse("index.html", context)
43
+
44
+ @app.post("/chat")
45
+ async def chat(user_input: str = Form(...), model_choice: str = Form(...), history: str = Form("[]")):
46
+ import json
47
+ history = json.loads(history)
48
+
49
+ tokenizer, model = get_model(model_choice)
50
+
51
+ # Build messages
52
+ messages = []
53
+ for role, content in history:
54
+ messages.append({"role": "user" if role == "user" else "assistant", "content": content})
55
+ messages.append({"role": "user", "content": user_input})
56
+
57
+ # Tokenize
58
+ inputs = tokenizer.apply_chat_template(
59
+ messages,
60
+ add_generation_prompt=True,
61
+ tokenize=True,
62
+ return_dict=True,
63
+ return_tensors="pt"
64
+ ).to(model.device)
65
+
66
+ outputs = model.generate(**inputs, max_new_tokens=512, temperature=0.7, top_p=0.9)
67
+ response = tokenizer.decode(outputs[0][inputs["input_ids"].shape[-1]:], skip_special_tokens=True)
68
+
69
+ # Update history
70
+ history.append(("user", user_input))
71
+ history.append(("assistant", response))
72
+
73
+ return JSONResponse({"response": response, "history": history})