Batrdj committed on
Commit
4ceec83
·
verified ·
1 Parent(s): 5a1a2fd

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +94 -45
app.py CHANGED
@@ -1,45 +1,94 @@
1
- from fastapi import FastAPI
2
- from pydantic import BaseModel
3
- from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
4
- import torch
5
-
6
- app = FastAPI()
7
-
8
- MODEL_NAME = "Qwen/Qwen2.5-Coder-7B"
9
-
10
- # ---- Quantization config (CPU safe) ----
11
- bnb_config = BitsAndBytesConfig(
12
- load_in_4bit=True,
13
- bnb_4bit_compute_dtype=torch.float32,
14
- bnb_4bit_use_double_quant=True,
15
- bnb_4bit_quant_type="nf4"
16
- )
17
-
18
- tokenizer = AutoTokenizer.from_pretrained(
19
- MODEL_NAME,
20
- trust_remote_code=True
21
- )
22
-
23
- model = AutoModelForCausalLM.from_pretrained(
24
- MODEL_NAME,
25
- device_map="cpu",
26
- quantization_config=bnb_config,
27
- trust_remote_code=True
28
- )
29
-
30
- class Prompt(BaseModel):
31
- message: str
32
-
33
- @app.post("/chat")
34
- def chat(prompt: Prompt):
35
- inputs = tokenizer(prompt.message, return_tensors="pt")
36
-
37
- outputs = model.generate(
38
- **inputs,
39
- max_new_tokens=200,
40
- temperature=0.7,
41
- do_sample=True
42
- )
43
-
44
- response = tokenizer.decode(outputs[0], skip_special_tokens=True)
45
- return {"response": response}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
from fastapi import FastAPI
from fastapi.responses import StreamingResponse
from pydantic import BaseModel
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    BitsAndBytesConfig,
    TextIteratorStreamer
)
import torch
import threading

app = FastAPI()

# Hugging Face model id served by both endpoints below.
MODEL_NAME = "Qwen/Qwen2.5-Coder-7B"

# ---- Quantization config (CPU safe) ----
# NOTE(review): bitsandbytes 4-bit quantization has historically required a
# CUDA GPU; combined with device_map="cpu" below this may fail to load —
# confirm against the installed bitsandbytes version.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,                     # store weights 4-bit (NF4)
    bnb_4bit_compute_dtype=torch.float32,  # compute in fp32 on CPU
    bnb_4bit_use_double_quant=True,        # also quantize the quant constants
    bnb_4bit_quant_type="nf4"
)

# trust_remote_code=True executes code shipped in the model repo; acceptable
# only because MODEL_NAME is a fixed, trusted repository.
tokenizer = AutoTokenizer.from_pretrained(
    MODEL_NAME,
    trust_remote_code=True
)

# Model is loaded once at import time and shared by all requests.
model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    device_map="cpu",
    quantization_config=bnb_config,
    trust_remote_code=True
)
36
+
37
class Prompt(BaseModel):
    """Request body for the /chat and /chat-stream endpoints."""

    message: str  # user input text, passed verbatim to the tokenizer
39
+
40
# -------------------------------------------------
# ✅ NORMAL CHAT (non-streaming)
# -------------------------------------------------
@app.post("/chat")
def chat(prompt: Prompt):
    """Generate a complete reply in a single response.

    Returns {"response": <text>} containing only the newly generated
    tokens. Previously the whole output sequence was decoded, so the
    user's prompt was echoed back at the start of every reply —
    inconsistent with /chat-stream, which uses skip_prompt=True.
    """
    inputs = tokenizer(prompt.message, return_tensors="pt")
    prompt_len = inputs["input_ids"].shape[-1]

    # inference_mode avoids building autograd state during generation.
    with torch.inference_mode():
        outputs = model.generate(
            **inputs,
            max_new_tokens=200,
            temperature=0.7,
            do_sample=True
        )

    # Decode only the tokens produced after the prompt.
    response = tokenizer.decode(outputs[0][prompt_len:], skip_special_tokens=True)
    return {"response": response}
56
+
57
+
58
# -------------------------------------------------
# 🚀 STREAMING CHAT (CHATGPT-LIKE)
# -------------------------------------------------
@app.post("/chat-stream")
def chat_stream(prompt: Prompt):
    """Stream the reply as plain text, token group by token group.

    model.generate runs in a background thread while TextIteratorStreamer
    hands decoded text pieces to the HTTP response as they are produced.
    """
    inputs = tokenizer(prompt.message, return_tensors="pt")

    # skip_prompt=True: stream only newly generated text, not an echo of
    # the user's prompt. The timeout ensures the response terminates with
    # an error instead of hanging forever if the generation thread dies.
    streamer = TextIteratorStreamer(
        tokenizer,
        skip_special_tokens=True,
        skip_prompt=True,
        timeout=300.0
    )

    generation_kwargs = dict(
        **inputs,
        streamer=streamer,
        max_new_tokens=200,
        temperature=0.7,
        do_sample=True
    )

    # daemon=True: a still-running generation for a disconnected client
    # must not block interpreter shutdown.
    thread = threading.Thread(
        target=model.generate,
        kwargs=generation_kwargs,
        daemon=True
    )
    thread.start()

    # TextIteratorStreamer is itself an iterator of text chunks, so it can
    # be passed to StreamingResponse directly — no wrapper generator needed.
    return StreamingResponse(streamer, media_type="text/plain")