Raiff1982 committed on
Commit
c8ea644
·
verified ·
1 Parent(s): bfb331a

Update app.py

Files changed (1)
  1. app.py +19 -75
app.py CHANGED
@@ -1,15 +1,11 @@
  """
  Codette AI Space — FastAPI + streaming chat API
- Compatible with the Ollama /api/chat streaming format so the HTML widget
- needs only a URL change to work.
-
- Adapter files should live in ./adapter/ inside this Space repo.
- Base model: meta-llama/Llama-3.2-1B
  """
  
  import json
  import asyncio
  import threading
+ import os
  from pathlib import Path
  from typing import Iterator
  
@@ -20,22 +16,14 @@ from fastapi.responses import StreamingResponse, HTMLResponse
  from peft import PeftModel
  from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
  
- # ── Config ───────────────────────────────────────────────────────────────────
  BASE_MODEL = "meta-llama/Llama-3.2-1B"
  ADAPTER_PATH = Path(__file__).parent / "adapter"
  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
+ HF_TOKEN = os.environ.get("HF_TOKEN")  # set as a Secret in Space settings
  
- # ── App ───────────────────────────────────────────────────────────────────────
  app = FastAPI(title="Codette AI")
+ app.add_middleware(CORSMiddleware, allow_origins=["*"], allow_methods=["*"], allow_headers=["*"])
  
- app.add_middleware(
-     CORSMiddleware,
-     allow_origins=["*"],  # Squarespace domain — keep open so the widget works
-     allow_methods=["*"],
-     allow_headers=["*"],
- )
-
- # ── Model (loaded once at startup) ────────────────────────────────────────────
  print(f"Loading tokenizer from {ADAPTER_PATH} …")
  tokenizer = AutoTokenizer.from_pretrained(str(ADAPTER_PATH))
  if tokenizer.pad_token is None:
@@ -47,24 +35,21 @@ base = AutoModelForCausalLM.from_pretrained(
      torch_dtype=torch.float16 if DEVICE == "cuda" else torch.float32,
      device_map="auto",
      low_cpu_mem_usage=True,
+     token=HF_TOKEN,
  )
  
- print(f"Loading LoRA adapter from {ADAPTER_PATH} …")
+ print(f"Loading LoRA adapter …")
  model = PeftModel.from_pretrained(base, str(ADAPTER_PATH))
-
- print("Merging LoRA weights into base model …")
- model = model.merge_and_unload()  # ← this is the actual merge step
+ print("Merging LoRA weights …")
+ model = model.merge_and_unload()
  model.eval()
  print(f"✅ Model ready on {DEVICE}")
  
  
- # ── Helpers ───────────────────────────────────────────────────────────────────
- def build_prompt(messages: list[dict]) -> str:
-     """Convert OpenAI-style messages to a simple Llama-3.2 instruct prompt."""
+ def build_prompt(messages):
      parts = []
      for m in messages:
-         role = m.get("role", "user")
-         content = m.get("content", "")
+         role, content = m.get("role", "user"), m.get("content", "")
          if role == "system":
              parts.append(f"<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n{content}<|eot_id|>")
          elif role == "user":
@@ -75,33 +60,17 @@ def build_prompt(messages: list[dict]) -> str:
      return "".join(parts)
  
  
- def stream_tokens(messages: list[dict], max_new_tokens: int = 512) -> Iterator[str]:
-     prompt = build_prompt(messages)
-     inputs = tokenizer(prompt, return_tensors="pt").to(DEVICE)
-
-     streamer = TextIteratorStreamer(
-         tokenizer, skip_prompt=True, skip_special_tokens=True
-     )
-
-     gen_kwargs = dict(
-         **inputs,
-         max_new_tokens=max_new_tokens,
-         do_sample=True,
-         temperature=0.7,
-         top_p=0.9,
-         streamer=streamer,
-     )
-
-     thread = threading.Thread(target=model.generate, kwargs=gen_kwargs)
+ def stream_tokens(messages, max_new_tokens=512):
+     inputs = tokenizer(build_prompt(messages), return_tensors="pt").to(DEVICE)
+     streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
+     thread = threading.Thread(target=model.generate, kwargs=dict(**inputs, max_new_tokens=max_new_tokens,
+                                                                  do_sample=True, temperature=0.7, top_p=0.9, streamer=streamer))
      thread.start()
-
      for token in streamer:
          yield token
-
      thread.join()
  
  
- # ── Routes ────────────────────────────────────────────────────────────────────
  @app.get("/", response_class=HTMLResponse)
  async def root():
      return "<h2>Codette AI is running ✅</h2><p>POST /api/chat to chat.</p>"
@@ -111,36 +80,11 @@ async def root():
  async def chat(request: Request):
      body = await request.json()
      messages = body.get("messages", [])
-     stream = body.get("stream", True)
  
-     if not stream:
-         # Non-streaming — collect everything first
-         full = "".join(stream_tokens(messages))
-         return {
-             "message": {"role": "assistant", "content": full},
-             "done": True,
-         }
-
-     # Streaming — mimic Ollama's NDJSON format exactly
      async def event_stream():
-         full = ""
          for token in stream_tokens(messages):
-             full += token
-             chunk = json.dumps({
-                 "message": {"role": "assistant", "content": token},
-                 "done": False,
-             })
-             yield chunk + "\n"
-             await asyncio.sleep(0)  # yield control to event loop
-
-         # Final message with done=true
-         yield json.dumps({
-             "message": {"role": "assistant", "content": ""},
-             "done": True,
-         }) + "\n"
-
-     return StreamingResponse(
-         event_stream(),
-         media_type="application/x-ndjson",
-         headers={"X-Accel-Buffering": "no"},
-     )
+             yield json.dumps({"message": {"role": "assistant", "content": token}, "done": False}) + "\n"
+             await asyncio.sleep(0)
+         yield json.dumps({"message": {"role": "assistant", "content": ""}, "done": True}) + "\n"
+
+     return StreamingResponse(event_stream(), media_type="application/x-ndjson", headers={"X-Accel-Buffering": "no"})
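
For reference, a minimal client sketch that consumes the Ollama-style NDJSON stream this endpoint emits (one {"message": ..., "done": ...} object per line). This is not part of the commit; the Space URL is a placeholder and the requests dependency is assumed.

import json
import requests  # assumed dependency

SPACE_URL = "https://your-space.hf.space"  # placeholder; substitute your Space's URL

def chat(messages):
    # POST the OpenAI-style message list; the server streams NDJSON chunks
    # shaped like {"message": {"role": "assistant", "content": ...}, "done": bool}.
    with requests.post(f"{SPACE_URL}/api/chat", json={"messages": messages}, stream=True) as resp:
        resp.raise_for_status()
        for line in resp.iter_lines():
            if not line:
                continue
            chunk = json.loads(line)
            if chunk.get("done"):
                break
            print(chunk["message"]["content"], end="", flush=True)

chat([{"role": "user", "content": "Hello, Codette!"}])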