triflix commited on
Commit
2a263c0
·
verified ·
1 Parent(s): 221c179

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +105 -41
main.py CHANGED
@@ -1,83 +1,147 @@
 
1
  from fastapi import FastAPI, HTTPException
2
- from pydantic import BaseModel
3
- from typing import List, Optional, Dict, Any
 
 
 
4
  import torch
5
  from transformers import AutoTokenizer, AutoModelForCausalLM
6
- import datetime
7
 
8
- # 1. Initialize App
9
- app = FastAPI(title="FunctionGemma Brain API")
 
10
 
11
- # 2. Global Variables for Model (Loaded on Startup)
12
  MODEL_ID = "google/functiongemma-270m-it"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
13
  tokenizer = None
14
  model = None
15
 
16
- # 3. Request Schema
17
- # This is what your Go Backend will send to this Python Service
 
 
 
18
  class ChatRequest(BaseModel):
19
- query: str
20
- tools: List[Dict[str, Any]] # The JSON schema of tools
21
- include_date: bool = True # Option to inject today's date
 
 
 
 
 
 
 
 
 
 
 
 
 
 
22
 
23
- # 4. Load Model on Startup
24
  @app.on_event("startup")
25
- async def load_model():
26
  global tokenizer, model
27
- print("🧠 Loading FunctionGemma 270M...")
 
 
 
 
 
 
 
 
28
  try:
29
  tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
30
- # Run on CPU (It's fast enough for 270M)
31
- model = AutoModelForCausalLM.from_pretrained(MODEL_ID, device_map="cpu")
32
- print("✅ Model Loaded Successfully!")
 
 
33
  except Exception as e:
34
- print(f" Failed to load model: {e}")
 
 
 
 
 
35
 
36
- # 5. The Endpoint
37
  @app.post("/generate")
38
  async def generate_function_call(request: ChatRequest):
39
- global tokenizer, model
40
-
41
- if not model or not tokenizer:
42
- raise HTTPException(status_code=503, detail="Model not loaded yet")
43
 
44
  try:
45
- # A. Prepare System Prompt
46
- today = datetime.date.today().strftime("%Y-%m-%d")
47
- system_content = "You are a model that can do function calling with the following functions."
 
48
  if request.include_date:
 
49
  system_content += f" Today is {today}."
50
 
51
- # B. Construct Messages
52
  messages = [
53
  {"role": "system", "content": system_content},
54
- {"role": "user", "content": request.query}
55
  ]
56
 
57
- # C. Apply Chat Template (This handles the JSON Schema formatting automatically)
58
  inputs = tokenizer.apply_chat_template(
59
  messages,
60
  tools=request.tools,
61
  add_generation_prompt=True,
 
62
  return_dict=True,
63
- return_tensors="pt"
64
  )
65
 
66
- # D. Generate
67
- # We limit tokens because we only want the function call, not a long story
68
- outputs = model.generate(**inputs, max_new_tokens=128)
69
-
70
- # E. Decode
71
- # We skip the input tokens to only get the new generated text
72
- generated_text = tokenizer.decode(outputs[0][len(inputs["input_ids"][0]):], skip_special_tokens=True)
 
 
 
73
 
74
  return {"response": generated_text}
75
 
76
  except Exception as e:
77
- print(f"Error during generation: {e}")
78
  raise HTTPException(status_code=500, detail=str(e))
79
 
80
- # Health check endpoint
81
- @app.get("/")
82
  def health_check():
83
- return {"status": "running", "model": MODEL_ID}
 
 
 
 
 
1
+ # app.py
2
  from fastapi import FastAPI, HTTPException
3
+ from pydantic import BaseModel, Field
4
+ from typing import List, Dict, Any
5
+ import os
6
+ import datetime
7
+
8
  import torch
9
  from transformers import AutoTokenizer, AutoModelForCausalLM
10
+ from huggingface_hub import login, HfHubHTTPError
11
 
12
+ # ==========================================
13
+ # 1. CONFIGURATION (Secure Defaults)
14
+ # ==========================================
15
 
 
16
  MODEL_ID = "google/functiongemma-270m-it"
17
+ HF_TOKEN_ENV = "HF_TOKEN"
18
+
19
+ def get_hf_token() -> str:
20
+ """
21
+ Fetch Hugging Face token from environment.
22
+
23
+ Raises:
24
+ RuntimeError: if token is missing
25
+ """
26
+ token = os.getenv(HF_TOKEN_ENV)
27
+ if not token:
28
+ raise RuntimeError(
29
+ f"Missing required environment variable: {HF_TOKEN_ENV}"
30
+ )
31
+ return token
32
+
33
+
34
+ # ==========================================
35
+ # 2. APP SETUP
36
+ # ==========================================
37
+
38
+ app = FastAPI(
39
+ title="FunctionGemma Brain API",
40
+ version="1.0.0",
41
+ )
42
+
43
  tokenizer = None
44
  model = None
45
 
46
+
47
+ # ==========================================
48
+ # 3. DATA MODELS
49
+ # ==========================================
50
+
51
  class ChatRequest(BaseModel):
52
+ """
53
+ Request schema for function-call generation.
54
+ """
55
+ query: str = Field(..., min_length=1, max_length=4096)
56
+ tools: List[Dict[str, Any]]
57
+ include_date: bool = True
58
+
59
+
60
+ class HealthResponse(BaseModel):
61
+ status: str
62
+ model: str
63
+ auth: str
64
+
65
+
66
+ # ==========================================
67
+ # 4. STARTUP (Auth + Load Model)
68
+ # ==========================================
69
 
 
70
  @app.on_event("startup")
71
+ async def startup():
72
  global tokenizer, model
73
+
74
+ # A. Authenticate (fail-fast)
75
+ try:
76
+ hf_token = get_hf_token()
77
+ login(token=hf_token)
78
+ except (RuntimeError, HfHubHTTPError) as e:
79
+ raise RuntimeError(f"Hugging Face authentication failed: {e}")
80
+
81
+ # B. Load Model
82
  try:
83
  tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
84
+ model = AutoModelForCausalLM.from_pretrained(
85
+ MODEL_ID,
86
+ device_map="cpu",
87
+ torch_dtype=torch.float32,
88
+ )
89
  except Exception as e:
90
+ raise RuntimeError(f"Model load failed: {e}")
91
+
92
+
93
+ # ==========================================
94
+ # 5. API ENDPOINT
95
+ # ==========================================
96
 
 
97
  @app.post("/generate")
98
  async def generate_function_call(request: ChatRequest):
99
+ if model is None or tokenizer is None:
100
+ raise HTTPException(status_code=503, detail="Model not ready")
 
 
101
 
102
  try:
103
+ # System context
104
+ system_content = (
105
+ "You are a model that can do function calling with the following functions."
106
+ )
107
  if request.include_date:
108
+ today = datetime.date.today().isoformat()
109
  system_content += f" Today is {today}."
110
 
 
111
  messages = [
112
  {"role": "system", "content": system_content},
113
+ {"role": "user", "content": request.query},
114
  ]
115
 
 
116
  inputs = tokenizer.apply_chat_template(
117
  messages,
118
  tools=request.tools,
119
  add_generation_prompt=True,
120
+ return_tensors="pt",
121
  return_dict=True,
 
122
  )
123
 
124
+ outputs = model.generate(
125
+ **inputs,
126
+ max_new_tokens=128,
127
+ do_sample=False, # deterministic
128
+ )
129
+
130
+ generated_text = tokenizer.decode(
131
+ outputs[0][len(inputs["input_ids"][0]):],
132
+ skip_special_tokens=True,
133
+ )
134
 
135
  return {"response": generated_text}
136
 
137
  except Exception as e:
 
138
  raise HTTPException(status_code=500, detail=str(e))
139
 
140
+
141
+ @app.get("/", response_model=HealthResponse)
142
  def health_check():
143
+ return {
144
+ "status": "running",
145
+ "model": MODEL_ID,
146
+ "auth": "env",
147
+ }