# NOTE(review): removed Hugging Face Spaces page residue ("Spaces:" /
# "Sleeping" status text) that was captured by the scrape — it is UI
# chrome, not part of the program, and breaks Python parsing if left bare.
import json

import torch
from fastapi import FastAPI, Form
from fastapi.responses import FileResponse, HTMLResponse, JSONResponse
from transformers import AutoModelForCausalLM, AutoTokenizer
# ---------------------------
# Models
# ---------------------------
# Display name (shown in the UI / sent back via the "model_choice" form
# field) -> Hugging Face Hub repo id.
MODEL_OPTIONS = {
    "DeepSeek Coder 1.3B": "deepseek-ai/deepseek-coder-1.3b-instruct",
    "StarCoder 1B": "bigcode/starcoderbase-1b",
    # FIX: the canonical hub repo id is spelled "CodeLlama", not
    # "CodeLLaMA" — the previous id does not resolve on the Hub.
    # (Display key kept as-is so the frontend keeps working.)
    "CodeLLaMA 7B": "codellama/CodeLlama-7b-Instruct-hf",
}

# Process-wide cache: display name -> (tokenizer, model), so each model is
# loaded from disk/hub at most once per server process.
loaded_models = {}
def get_model(model_key):
    """Return the cached ``(tokenizer, model)`` pair for *model_key*.

    The first request for a key loads the weights (downloading from the
    Hub if necessary) and stores them in ``loaded_models``; subsequent
    requests are served from the cache.

    Raises:
        KeyError: if *model_key* is not a key of ``MODEL_OPTIONS``.
    """
    try:
        # Fast path: already loaded in this process.
        return loaded_models[model_key]
    except KeyError:
        pass  # first use of this model — load it below

    repo_id = MODEL_OPTIONS[model_key]
    tok = AutoTokenizer.from_pretrained(repo_id)
    # fp16 halves GPU memory; CPU inference needs fp32 for op coverage.
    dtype = torch.float16 if torch.cuda.is_available() else torch.float32
    lm = AutoModelForCausalLM.from_pretrained(
        repo_id,
        torch_dtype=dtype,
        device_map="auto",
    )
    loaded_models[model_key] = (tok, lm)
    return loaded_models[model_key]
# ---------------------------
# FastAPI setup
# ---------------------------
# Single application instance; the route handlers below attach to it.
app = FastAPI()

# ---------------------------
# Routes
# ---------------------------
@app.get("/")
async def home():
    """Serve the chat UI (``index.html`` from the working directory).

    FIX: the handler previously had no route decorator, so it was never
    registered with FastAPI and ``/`` returned 404.
    """
    return FileResponse("index.html")
@app.get("/style.css")
async def css():
    """Serve the stylesheet referenced by ``index.html``.

    FIX: the handler previously had no route decorator, so it was never
    registered with FastAPI. Path assumes index.html links the sheet as
    a relative "style.css" — TODO confirm against the HTML.
    """
    return FileResponse("style.css")
@app.post("/chat")
async def chat(user_input: str = Form(...), model_choice: str = Form(...), history: str = Form("[]")):
    """Run one chat turn against the selected model.

    Form fields:
        user_input:   the new user message.
        model_choice: a display-name key of ``MODEL_OPTIONS``.
        history:      JSON list of ``[role, content]`` pairs from prior turns.

    Returns:
        JSON ``{"response": <assistant text>, "history": <updated pairs>}``.
    """
    # FIX: the handler previously had no @app.post decorator, so it was
    # never registered with FastAPI.
    try:
        past = json.loads(history)
    except json.JSONDecodeError:
        # FIX: malformed client state used to 500 the endpoint; fall back
        # to a fresh conversation instead.
        past = []
    tokenizer, model = get_model(model_choice)

    # Build the chat-template message list from prior turns + the new input.
    messages = [
        {"role": "user" if role == "user" else "assistant", "content": content}
        for role, content in past
    ]
    messages.append({"role": "user", "content": user_input})

    inputs = tokenizer.apply_chat_template(
        messages,
        add_generation_prompt=True,
        tokenize=True,
        return_dict=True,
        return_tensors="pt",
    ).to(model.device)

    # FIX: inference_mode — generation needs no autograd bookkeeping, and
    # FIX: do_sample=True — without it generate() decodes greedily and the
    # temperature/top_p arguments are silently ignored.
    with torch.inference_mode():
        outputs = model.generate(
            **inputs,
            max_new_tokens=512,
            do_sample=True,
            temperature=0.7,
            top_p=0.9,
        )

    # Decode only the newly generated tokens, skipping the echoed prompt.
    prompt_len = inputs["input_ids"].shape[-1]
    response = tokenizer.decode(outputs[0][prompt_len:], skip_special_tokens=True)

    # Append this turn so the client can send the state back next call.
    past.append(("user", user_input))
    past.append(("assistant", response))
    return JSONResponse({"response": response, "history": past})