import os

# Set the Hugging Face cache location before transformers is imported, so the
# cache paths are picked up (e.g. on Spaces where only /tmp is writable).
os.environ["HF_HOME"] = "/tmp/huggingface"

import uuid

import torch
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from transformers import AutoTokenizer, AutoModelForCausalLM

app = FastAPI()

model_id = "hugging-quants/Meta-Llama-3.1-8B-Instruct-GPTQ-INT4"
hf_token = os.environ.get("HF_TOKEN")
tokenizer = AutoTokenizer.from_pretrained(model_id, token=hf_token)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.float16,
    device_map="auto",
    low_cpu_mem_usage=True,
    token=hf_token,
)
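# NOTE: loading this GPTQ checkpoint requires a GPTQ backend to be installed
# alongside transformers (e.g. the optimum and auto-gptq packages); the exact
# requirements depend on the installed transformers version.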
# Store per-session system prompts
session_prompts = {}
class SystemPrompt(BaseModel):
    prompt: str


class UserMessage(BaseModel):
    session_id: str
    message: str


@app.post("/start")
def start_chat(system_prompt: SystemPrompt):
    session_id = str(uuid.uuid4())
    session_prompts[session_id] = system_prompt.prompt
    return {"session_id": session_id}
@app.post("/chat")
def chat(message: UserMessage):
system = session_prompts.get(message.session_id)
if not system:
return {"error": "Invalid session_id. Call /start first."}
full_prompt = f"<|system|>\n{system}\n<|user|>\n{message.message}\n<|assistant|>\n"
inputs = tokenizer(full_prompt, return_tensors="pt").to(model.device)
outputs = model.generate(
**inputs,
max_new_tokens=200,
pad_token_id=tokenizer.eos_token_id,
)
response = tokenizer.decode(outputs[0], skip_special_tokens=True)
# Strip input part to isolate model's answer
answer = response.replace(full_prompt.strip(), "").strip()
return {"response": answer}