bsny commited on
Commit
6f5caa7
·
verified ·
1 Parent(s): c97c23d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +43 -7
app.py CHANGED
@@ -1,20 +1,56 @@
1
- import os
 
2
  from transformers import AutoTokenizer, AutoModelForCausalLM
3
  import torch
 
 
 
 
4
 
5
- model_id = "hugging-quants/Meta-Llama-3.1-8B-Instruct-GPTQ-INT4"
6
  hf_token = os.environ.get("HF_TOKEN")
7
 
8
  tokenizer = AutoTokenizer.from_pretrained(model_id, token=hf_token)
9
  model = AutoModelForCausalLM.from_pretrained(
10
  model_id,
11
- device_map="auto",
12
  torch_dtype=torch.float16,
 
13
  low_cpu_mem_usage=True,
14
  token=hf_token
15
  )
16
 
17
def generate(prompt):
    """Tokenize *prompt*, run generation, and decode the full sequence.

    Note: the returned text includes the prompt itself, since the whole
    output tensor is decoded.
    """
    encoded = tokenizer(prompt, return_tensors="pt").to(model.device)
    generated = model.generate(**encoded, max_new_tokens=128)
    return tokenizer.decode(generated[0], skip_special_tokens=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
from fastapi import FastAPI, Request
from pydantic import BaseModel
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
import os
import uuid

app = FastAPI()

# Generalization: the model id was a hard-coded placeholder; make it
# configurable via the MODEL_ID env var, keeping the original placeholder
# as the backward-compatible default.
model_id = os.environ.get("MODEL_ID", "huggingface/your-quantized-model-id")
hf_token = os.environ.get("HF_TOKEN")  # None is fine for public models

tokenizer = AutoTokenizer.from_pretrained(model_id, token=hf_token)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.float16,
    device_map="auto",        # let accelerate place layers on available devices
    low_cpu_mem_usage=True,   # avoid materializing a full fp32 copy on CPU
    token=hf_token,
)
21
 
22
# In-memory map: session_id -> system prompt for that session.
# NOTE(review): this grows without bound and lives per-process only —
# consider a TTL cache / external store before production use.
session_prompts = {}


class SystemPrompt(BaseModel):
    # System prompt that steers every reply in the session.
    prompt: str


class UserMessage(BaseModel):
    # Session identifier previously returned by /start.
    session_id: str
    # The user's chat message for this turn.
    message: str
31
+
32
+ @app.post("/start")
33
+ def start_chat(system_prompt: SystemPrompt):
34
+ session_id = str(uuid.uuid4())
35
+ session_prompts[session_id] = system_prompt.prompt
36
+ return {"session_id": session_id}
37
+
38
+ @app.post("/chat")
39
+ def chat(message: UserMessage):
40
+ system = session_prompts.get(message.session_id)
41
+ if not system:
42
+ return {"error": "Invalid session_id. Call /start first."}
43
+
44
+ full_prompt = f"<|system|>\n{system}\n<|user|>\n{message.message}\n<|assistant|>\n"
45
+
46
+ inputs = tokenizer(full_prompt, return_tensors="pt").to(model.device)
47
+ outputs = model.generate(
48
+ **inputs,
49
+ max_new_tokens=200,
50
+ pad_token_id=tokenizer.eos_token_id,
51
+ )
52
+ response = tokenizer.decode(outputs[0], skip_special_tokens=True)
53
+
54
+ # Strip input part to isolate model's answer
55
+ answer = response.replace(full_prompt.strip(), "").strip()
56
+ return {"response": answer}