airsltd committed on
Commit
5392a8d
·
verified ·
1 Parent(s): 108f42d

Upload 4 files

Browse files
Files changed (4) hide show
  1. .gitignore +55 -0
  2. Dockerfile +16 -0
  3. app.py +194 -0
  4. requirements.txt +6 -0
.gitignore ADDED
@@ -0,0 +1,55 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Environment variables
2
+ .env
3
+ .env.local
4
+ .env.*.local
5
+
6
+ # Python
7
+ __pycache__/
8
+ *.py[cod]
9
+ *$py.class
10
+ *.so
11
+ .Python
12
+ build/
13
+ develop-eggs/
14
+ dist/
15
+ downloads/
16
+ eggs/
17
+ .eggs/
18
+ lib/
19
+ lib64/
20
+ parts/
21
+ sdist/
22
+ var/
23
+ wheels/
24
+ *.egg-info/
25
+ .installed.cfg
26
+ *.egg
27
+
28
+ # Virtual environments
29
+ venv/
30
+ ENV/
31
+ env/
32
+
33
+ # Model cache
34
+ my_model_cache/
35
+ *.bin
36
+ *.safetensors
37
+
38
+ # IDE
39
+ .vscode/
40
+ .idea/
41
+ *.swp
42
+ *.swo
43
+ *~
44
+
45
+ # OS
46
+ .DS_Store
47
+ Thumbs.db
48
+
49
+ # Logs
50
+ *.log
51
+ logs/
52
+
53
+ # Temporary files
54
+ *.tmp
55
+ *.temp
Dockerfile ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# Hugging Face Spaces Docker SDK reference:
# https://huggingface.co/docs/hub/spaces-sdks-docker

FROM python:3.9

# Run as an unprivileged user (HF Spaces convention: uid 1000).
RUN useradd -m -u 1000 user
USER user
ENV PATH="/home/user/.local/bin:$PATH"

WORKDIR /app

# Install dependencies first so this layer is cached across code changes.
COPY --chown=user ./requirements.txt requirements.txt
RUN pip install --no-cache-dir --upgrade -r requirements.txt

# Copy the application code.
# NOTE(review): app.py appears to create its FastAPI `app` inside main(),
# so `uvicorn app:app` may not resolve a module-level `app` — verify.
COPY --chown=user . /app
CMD ["uvicorn", "app:app", "--host", "0.0.0.0", "--port", "7860"]
app.py ADDED
@@ -0,0 +1,194 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ Combined application that automatically downloads the model if needed and starts the FastAPI server.
4
+ """
5
+
6
+ import os
7
+ import sys
8
+ from pathlib import Path
9
+
10
+ # Check if model exists, if not download it
11
def check_and_download_model():
    """Ensure the model is available in the local cache, downloading it if needed.

    Checks the Hugging Face hub cache layout under ``cache_dir`` for an
    existing snapshot of the model; if none is found, downloads the
    tokenizer and model weights into the cache.

    Returns:
        tuple[str, str]: ``(model_name, cache_dir)`` once the model is available.

    Exits the process with status 1 if the download fails.
    """
    from transformers import AutoTokenizer, AutoModelForCausalLM
    from huggingface_hub import login

    # TODO: next, evaluate mlx-community/functiongemma-270m-it-4bit
    # Previously tested with TinyLlama - a fully public model:
    # model_name = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
    model_name = "unsloth/functiongemma-270m-it"
    cache_dir = "./my_model_cache"

    # HF hub cache layout: models--{org}--{name}/snapshots/<revision>/
    model_path = Path(cache_dir) / f"models--{model_name.replace('/', '--')}"
    snapshot_path = model_path / "snapshots"

    # A non-empty snapshots directory means at least one revision is cached.
    if snapshot_path.exists() and any(snapshot_path.iterdir()):
        print(f"✓ Model {model_name} already exists in cache")
        return model_name, cache_dir

    print(f"✗ Model {model_name} not found in cache")
    print("Downloading model...")

    # Login to Hugging Face (optional; only needed for gated models).
    token = os.getenv("HUGGINGFACE_TOKEN")
    if token:
        try:
            print("Logging in to Hugging Face...")
            login(token=token)
            print("✓ HuggingFace login successful!")
        except Exception as e:
            print(f"⚠ Login failed: {e}")
            print("Continuing without login (public models only)")
    else:
        print("ℹ No HUGGINGFACE_TOKEN set - using public models only")

    try:
        # from_pretrained populates cache_dir as a side effect; the returned
        # objects are not needed here, so they are deliberately not kept.
        print("Loading tokenizer...")
        AutoTokenizer.from_pretrained(model_name, cache_dir=cache_dir)
        print("✓ Tokenizer loaded successfully!")

        print("Loading model...")
        AutoModelForCausalLM.from_pretrained(model_name, cache_dir=cache_dir)
        print("✓ Model loaded successfully!")

        print(f"✓ Model and tokenizer downloaded successfully to {cache_dir}")
        return model_name, cache_dir

    except Exception as e:
        print(f"✗ Error downloading model: {e}")
        print("\nPossible reasons:")
        print("1. Model requires authentication - set HUGGINGFACE_TOKEN in .env")
        print("2. Model is gated and you don't have access")
        print("3. Network connection issues")
        sys.exit(1)
67
+
68
def main():
    """Download the model if needed, build the FastAPI app, and serve it."""
    print("=" * 60)
    print("FunctionGemma FastAPI Server")
    print("=" * 60)

    # Ensure the model is cached locally before the server starts.
    model_name, cache_dir = check_and_download_model()

    print("\nStarting FastAPI server...")

    from fastapi import FastAPI
    from transformers import pipeline

    # NOTE(review): the Dockerfile CMD runs `uvicorn app:app`, which expects a
    # module-level `app`; here `app` only exists inside main() — verify which
    # entry point is actually used in deployment.
    app = FastAPI(title="FunctionGemma API", version="1.0.0")

    # Build the text-generation pipeline once; shared by all endpoints.
    print(f"Initializing pipeline with {model_name}...")
    pipe = pipeline("text-generation", model=model_name)
    print("✓ Pipeline initialized successfully!")

    @app.get("/")
    def greet_json():
        """Welcome message with basic service status."""
        return {
            "message": "FunctionGemma API is running!",
            "model": model_name,
            "status": "ready"
        }

    @app.get("/health")
    def health_check():
        """Liveness probe."""
        return {"status": "healthy", "model": model_name}

    @app.get("/generate")
    def generate_text(prompt: str = "Who are you?"):
        """Generate text for a single user prompt."""
        messages = [{"role": "user", "content": prompt}]
        result = pipe(messages, max_new_tokens=100)
        return {"response": result[0]["generated_text"]}

    @app.post("/chat")
    def chat_completion(messages: list):
        """Chat completion over a raw message list."""
        result = pipe(messages, max_new_tokens=200)
        return {"response": result[0]["generated_text"]}

    @app.post("/v1/chat/completions")
    def openai_chat_completions(request: dict):
        """OpenAI-compatible chat completions endpoint.

        Expected request format:
        {
            "model": "google/gemma-2b-it",
            "messages": [
                {"role": "user", "content": "Hello"}
            ],
            "max_tokens": 100,
            "temperature": 0.7
        }
        """
        import time

        messages = request.get("messages", [])
        model = request.get("model", model_name)
        max_tokens = request.get("max_tokens", 100)
        # NOTE(review): temperature is accepted but not yet forwarded to the
        # pipeline (the sampling argument was commented out in the original).
        temperature = request.get("temperature", 0.7)

        # Generate the assistant response.
        result = pipe(
            messages,
            max_new_tokens=max_tokens,
        )

        completion_id = f"chatcmpl-{int(time.time())}"
        created = int(time.time())

        return {
            "id": completion_id,
            "object": "chat.completion",
            "created": created,
            "model": model,
            "choices": [
                {
                    "index": 0,
                    "message": {
                        "role": "assistant",
                        "content": result[0]["generated_text"]
                    },
                    "finish_reason": "stop"
                }
            ],
            "usage": {
                "prompt_tokens": 0,  # would need the tokenizer to count
                "completion_tokens": 0,
                "total_tokens": 0
            }
        }

    # Run the server. Port must match the Dockerfile / HF Spaces port (7860).
    import uvicorn
    print("\n" + "=" * 60)
    print("Server starting at http://localhost:7860")
    print("Available endpoints:")
    print("  GET  /              - Welcome message")
    print("  GET  /health        - Health check")
    print("  GET  /generate?prompt=... - Generate text with prompt")
    print("  POST /chat          - Chat completion")
    print("  POST /v1/chat/completions - OpenAI-compatible endpoint")
    print("=" * 60 + "\n")

    uvicorn.run(app, host="0.0.0.0", port=7860)


if __name__ == "__main__":
    main()
requirements.txt ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ fastapi
2
+ uvicorn[standard]
3
+ transformers
4
+ huggingface_hub
5
+ torch
6
+ accelerate