Proff12 committed
Commit 958f33b · verified · 1 Parent(s): 636350e

Update backend/app/main.py

Files changed (1):
  1. backend/app/main.py  +164 -165
backend/app/main.py CHANGED
@@ -1,172 +1,171 @@
import os
from typing import List, Literal, Optional

from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from fastapi.staticfiles import StaticFiles
from pydantic import BaseModel
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
import torch

APP_TITLE = "HF Chat (Fathom-R1-14B)"
APP_VERSION = "0.2.0"

# ---- Config via ENV ----
MODEL_ID = os.getenv("MODEL_ID", "FractalAIResearch/Fathom-R1-14B")
PIPELINE_TASK = os.getenv("PIPELINE_TASK", "text-generation")
MAX_INPUT_TOKENS = int(os.getenv("MAX_INPUT_TOKENS", "8192"))  # keep prompt reasonable
STATIC_DIR = os.getenv("STATIC_DIR", "/app/static")
ALLOWED_ORIGINS = os.getenv("ALLOWED_ORIGINS", "")
QUANTIZE = os.getenv("QUANTIZE", "auto")  # auto|4bit|8bit|none
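
# Example launch (hypothetical values; adjust the module path, host, and port
# to your deployment):
#   MODEL_ID=FractalAIResearch/Fathom-R1-14B QUANTIZE=4bit \
#   ALLOWED_ORIGINS=https://example.com \
#   uvicorn app.main:app --host 0.0.0.0 --port 8000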

app = FastAPI(title=APP_TITLE, version=APP_VERSION)

if ALLOWED_ORIGINS:
    origins = [o.strip() for o in ALLOWED_ORIGINS.split(",") if o.strip()]
    app.add_middleware(
        CORSMiddleware,
        allow_origins=origins,
        allow_credentials=True,
        allow_methods=["*"],
        allow_headers=["*"],
    )

class Message(BaseModel):
    role: Literal["system", "user", "assistant"]
    content: str

class ChatRequest(BaseModel):
    messages: List[Message]
    max_new_tokens: int = 512
    temperature: float = 0.7
    top_p: float = 0.95
    repetition_penalty: Optional[float] = 1.0
    stop: Optional[List[str]] = None

class ChatResponse(BaseModel):
    reply: str
    model: str

tokenizer = None
model = None
generator = None

def load_pipeline():
    global tokenizer, model, generator
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # Load tokenizer
    tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, use_fast=True)
    if tokenizer.pad_token is None and tokenizer.eos_token is not None:
        tokenizer.pad_token = tokenizer.eos_token

    # Determine load strategy
    load_kwargs = {}
    dtype = torch.bfloat16 if device == "cuda" else torch.float32

    if device == "cuda":
        # try quantization if requested
        if QUANTIZE.lower() in ("4bit", "8bit", "auto"):
            try:
                import bitsandbytes as bnb  # noqa: F401
                if QUANTIZE.lower() == "8bit":
                    load_kwargs.update(dict(load_in_8bit=True))
                else:
                    # 4bit or auto (prefer 4bit)
                    load_kwargs.update(dict(load_in_4bit=True, bnb_4bit_compute_dtype=torch.bfloat16))
            except Exception:
                # bitsandbytes not available; fall back to full precision on GPU
                pass
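        # Note: load_in_8bit/load_in_4bit as direct from_pretrained kwargs are
        # the legacy bitsandbytes interface; recent transformers releases prefer
        # a BitsAndBytesConfig passed via quantization_config. Behavior here
        # depends on the installed transformers version.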
        load_kwargs.setdefault("torch_dtype", dtype)
        load_kwargs.setdefault("device_map", "auto")
    else:
        # CPU fallback
        load_kwargs.setdefault("torch_dtype", dtype)

    model = AutoModelForCausalLM.from_pretrained(MODEL_ID, **load_kwargs)

    generator = pipeline(
        PIPELINE_TASK,
        model=model,
        tokenizer=tokenizer,
        device_map=load_kwargs.get("device_map", None) or (0 if device == "cuda" else -1),
    )

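# Note: @app.on_event("startup") still works but is deprecated in recent
# FastAPI in favor of lifespan context managers.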
@app.on_event("startup")
def _startup():
    load_pipeline()

def messages_to_prompt(messages: List[Message]) -> str:
    """
    Prefer the tokenizer's chat template (Qwen-based models ship one). Fall back to a simple transcript.
    """
    try:
        # Convert to HF chat format: list of dicts with role/content
        chat = [{"role": m.role, "content": m.content} for m in messages]
        return tokenizer.apply_chat_template(chat, tokenize=False, add_generation_prompt=True)
    except Exception:
        # Fallback formatting
        parts = []
        for m in messages:
            if m.role == "system":
                parts.append(f"System: {m.content}\n")
            elif m.role == "user":
                parts.append(f"User: {m.content}\n")
            else:
                parts.append(f"Assistant: {m.content}\n")
        parts.append("Assistant:")
        return "\n".join(parts)
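
# With the fallback path, [system, user] messages become a transcript of the
# form: "System: <content>\n\nUser: <content>\n\nAssistant:"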
 
def truncate_prompt(prompt: str, max_tokens: int) -> str:
    ids = tokenizer(prompt, return_tensors="pt", truncation=False)["input_ids"][0]
    if len(ids) <= max_tokens:
        return prompt
    trimmed = ids[-max_tokens:]
    return tokenizer.decode(trimmed, skip_special_tokens=True)
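
# Truncation keeps the newest tokens (the tail of the prompt), so the oldest
# turns are dropped first; note that special tokens from the chat template can
# be lost in the decode/re-encode round trip.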

@app.get("/api/health")
def health():
    device = next(model.parameters()).device.type if model is not None else "N/A"
    return {"status": "ok", "model": MODEL_ID, "task": PIPELINE_TASK, "device": device}

@app.post("/api/chat", response_model=ChatResponse)
def chat(req: ChatRequest):
    if generator is None:
        raise HTTPException(status_code=503, detail="Model not loaded")
    if not req.messages:
        raise HTTPException(status_code=400, detail="messages cannot be empty")

    raw_prompt = messages_to_prompt(req.messages)
    prompt = truncate_prompt(raw_prompt, MAX_INPUT_TOKENS)

    gen_kwargs = {
        "max_new_tokens": req.max_new_tokens,
        "do_sample": req.temperature > 0,
        "temperature": req.temperature,
        "top_p": req.top_p,
        "repetition_penalty": req.repetition_penalty,
        "eos_token_id": tokenizer.eos_token_id,
        "pad_token_id": tokenizer.pad_token_id,
        "return_full_text": True,
    }
    if req.stop:
        gen_kwargs["stop"] = req.stop
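    # Note: a plain "stop" kwarg is not a standard transformers generate
    # argument; recent versions expose stop_strings instead, so this may be
    # ignored or raise depending on the installed version.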

    outputs = generator(prompt, **gen_kwargs)
    if isinstance(outputs, list) and outputs and "generated_text" in outputs[0]:
        full = outputs[0]["generated_text"]
        reply = full[len(prompt):].strip() if full.startswith(prompt) else full
    else:
        reply = str(outputs)
    if not reply:
        reply = "(No response generated.)"
    return ChatResponse(reply=reply, model=MODEL_ID)

# Serve frontend build (if present)
if os.path.isdir(STATIC_DIR):
    app.mount("/", StaticFiles(directory=STATIC_DIR, html=True), name="static")
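
For reference, a minimal client sketch against the /api/chat endpoint; the base URL http://localhost:8000 is an assumption, so substitute your deployment's host and port:

import requests  # hypothetical client; any HTTP library works

resp = requests.post(
    "http://localhost:8000/api/chat",  # assumed base URL
    json={
        "messages": [
            {"role": "system", "content": "You are a helpful assistant."},
            {"role": "user", "content": "Say hello."},
        ],
        "max_new_tokens": 128,
        "temperature": 0.7,
    },
    timeout=600,  # generation can be slow, especially on CPU for a 14B model
)
resp.raise_for_status()
print(resp.json()["reply"])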