Valtry committed on
Commit
20bcc59
·
verified ·
1 Parent(s): 3605faf

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +65 -14
app.py CHANGED
@@ -1,7 +1,9 @@
1
- from fastapi import FastAPI
 
2
  from pydantic import BaseModel
3
  import torch
4
- from transformers import AutoModelForCausalLM, AutoTokenizer
 
5
  import uvicorn
6
 
7
  # -----------------------
@@ -24,33 +26,82 @@ torch.set_num_threads(2)
24
  # -----------------------
25
  app = FastAPI()
26
 
 
 
 
27
  class ChatRequest(BaseModel):
28
  message: str
29
 
 
30
  @app.get("/")
31
  def home():
32
- return {"status": "API running 🚀"}
 
33
 
 
 
 
 
 
 
 
 
 
 
 
 
 
34
  @app.post("/chat")
35
- def chat(req: ChatRequest):
36
- prompt = f"User: {req.message}\nAssistant:"
 
 
 
 
 
 
 
 
 
 
 
 
37
 
38
  inputs = tokenizer(prompt, return_tensors="pt")
39
 
40
- outputs = model.generate(
41
- **inputs,
42
- max_new_tokens=80,
43
- temperature=0.7,
44
- do_sample=True
45
  )
46
 
47
- reply = tokenizer.decode(outputs[0], skip_special_tokens=True)
48
- reply = reply.split("Assistant:")[-1].strip()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
49
 
50
- return {"response": reply}
51
 
52
  # -----------------------
53
- # START SERVER DIRECTLY
54
  # -----------------------
55
  if __name__ == "__main__":
56
  uvicorn.run(app, host="0.0.0.0", port=7860)
 
1
+ from fastapi import FastAPI, Request
2
+ from fastapi.responses import StreamingResponse
3
  from pydantic import BaseModel
4
  import torch
5
+ from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
6
+ from threading import Thread
7
  import uvicorn
8
 
9
  # -----------------------
 
26
  # -----------------------
27
  app = FastAPI()
28
 
29
+ # stop flag (global)
30
+ stop_generation = False
31
+
32
  class ChatRequest(BaseModel):
33
  message: str
34
 
35
+
36
  @app.get("/")
37
  def home():
38
+ return {"status": "Streaming API running 🚀"}
39
+
40
 
41
+ # -----------------------
42
+ # STOP ENDPOINT
43
+ # -----------------------
44
+ @app.post("/stop")
45
+ def stop():
46
+ global stop_generation
47
+ stop_generation = True
48
+ return {"status": "stopping"}
49
+
50
+
51
+ # -----------------------
52
+ # STREAMING CHAT
53
+ # -----------------------
54
  @app.post("/chat")
55
+ async def chat(req: ChatRequest):
56
+
57
+ global stop_generation
58
+ stop_generation = False
59
+
60
+ # 🔥 FORCE SHORT ANSWERS
61
+ prompt = f"""
62
+ You are a concise assistant.
63
+ Answer VERY SHORT (1-2 lines max).
64
+ No long explanations.
65
+
66
+ User: {req.message}
67
+ Assistant:
68
+ """
69
 
70
  inputs = tokenizer(prompt, return_tensors="pt")
71
 
72
+ streamer = TextIteratorStreamer(
73
+ tokenizer,
74
+ skip_prompt=True,
75
+ skip_special_tokens=True
 
76
  )
77
 
78
+ def generate():
79
+ model.generate(
80
+ **inputs,
81
+ streamer=streamer,
82
+ max_new_tokens=40, # 🔥 short output
83
+ temperature=0.6,
84
+ do_sample=True,
85
+ eos_token_id=tokenizer.eos_token_id
86
+ )
87
+
88
+ thread = Thread(target=generate)
89
+ thread.start()
90
+
91
+ async def stream():
92
+ global stop_generation
93
+
94
+ for token in streamer:
95
+ if stop_generation:
96
+ break
97
+
98
+ yield token # 🔥 real-time streaming
99
+
100
+ return StreamingResponse(stream(), media_type="text/plain")
101
 
 
102
 
103
  # -----------------------
104
+ # START SERVER
105
  # -----------------------
106
  if __name__ == "__main__":
107
  uvicorn.run(app, host="0.0.0.0", port=7860)