triflix commited on
Commit
e72da52
·
verified ·
1 Parent(s): e81c5d7

Create main.py

Browse files
Files changed (1) hide show
  1. main.py +83 -0
main.py ADDED
@@ -0,0 +1,83 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fastapi import FastAPI, HTTPException
2
+ from pydantic import BaseModel
3
+ from typing import List, Optional, Dict, Any
4
+ import torch
5
+ from transformers import AutoTokenizer, AutoModelForCausalLM
6
+ import datetime
7
+
8
+ # 1. Initialize App
9
+ app = FastAPI(title="FunctionGemma Brain API")
10
+
11
+ # 2. Global Variables for Model (Loaded on Startup)
12
+ MODEL_ID = "google/functiongemma-270m-it"
13
+ tokenizer = None
14
+ model = None
15
+
16
+ # 3. Request Schema
17
+ # This is what your Go Backend will send to this Python Service
18
+ class ChatRequest(BaseModel):
19
+ query: str
20
+ tools: List[Dict[str, Any]] # The JSON schema of tools
21
+ include_date: bool = True # Option to inject today's date
22
+
23
+ # 4. Load Model on Startup
24
+ @app.on_event("startup")
25
+ async def load_model():
26
+ global tokenizer, model
27
+ print("🧠 Loading FunctionGemma 270M...")
28
+ try:
29
+ tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
30
+ # Run on CPU (It's fast enough for 270M)
31
+ model = AutoModelForCausalLM.from_pretrained(MODEL_ID, device_map="cpu")
32
+ print("✅ Model Loaded Successfully!")
33
+ except Exception as e:
34
+ print(f"❌ Failed to load model: {e}")
35
+
36
+ # 5. The Endpoint
37
+ @app.post("/generate")
38
+ async def generate_function_call(request: ChatRequest):
39
+ global tokenizer, model
40
+
41
+ if not model or not tokenizer:
42
+ raise HTTPException(status_code=503, detail="Model not loaded yet")
43
+
44
+ try:
45
+ # A. Prepare System Prompt
46
+ today = datetime.date.today().strftime("%Y-%m-%d")
47
+ system_content = "You are a model that can do function calling with the following functions."
48
+ if request.include_date:
49
+ system_content += f" Today is {today}."
50
+
51
+ # B. Construct Messages
52
+ messages = [
53
+ {"role": "system", "content": system_content},
54
+ {"role": "user", "content": request.query}
55
+ ]
56
+
57
+ # C. Apply Chat Template (This handles the JSON Schema formatting automatically)
58
+ inputs = tokenizer.apply_chat_template(
59
+ messages,
60
+ tools=request.tools,
61
+ add_generation_prompt=True,
62
+ return_dict=True,
63
+ return_tensors="pt"
64
+ )
65
+
66
+ # D. Generate
67
+ # We limit tokens because we only want the function call, not a long story
68
+ outputs = model.generate(**inputs, max_new_tokens=128)
69
+
70
+ # E. Decode
71
+ # We skip the input tokens to only get the new generated text
72
+ generated_text = tokenizer.decode(outputs[0][len(inputs["input_ids"][0]):], skip_special_tokens=True)
73
+
74
+ return {"response": generated_text}
75
+
76
+ except Exception as e:
77
+ print(f"Error during generation: {e}")
78
+ raise HTTPException(status_code=500, detail=str(e))
79
+
80
+ # Health check endpoint
81
+ @app.get("/")
82
+ def health_check():
83
+ return {"status": "running", "model": MODEL_ID}