Samarth Naik committed
Commit 01fa9b6 · 1 Parent(s): 414d456

added init files
Files changed (5)
  1. Dockerfile +24 -0
  2. README.md +123 -4
  3. app.py +209 -0
  4. requirements.txt +7 -0
  5. test_api.py +60 -0
Dockerfile ADDED
@@ -0,0 +1,24 @@
+ FROM python:3.9-slim
+
+ WORKDIR /code
+
+ # Install system dependencies
+ RUN apt-get update && apt-get install -y \
+     git \
+     curl \
+     && rm -rf /var/lib/apt/lists/*
+
+ # Copy requirements first for better caching
+ COPY requirements.txt .
+
+ # Install Python dependencies
+ RUN pip install --no-cache-dir -r requirements.txt
+
+ # Copy application code
+ COPY . .
+
+ # Expose port
+ EXPOSE 5001
+
+ # Run the application
+ CMD ["python", "app.py"]
README.md CHANGED
@@ -1,12 +1,131 @@
+ # Llama-3.1-8B-Instruct Flask API
+
  ---
  title: Llamamodel
  emoji: ⚡
  colorFrom: yellow
  colorTo: pink
- sdk: gradio
- sdk_version: 6.2.0
- app_file: app.py
+ sdk: docker
+ app_port: 5001
  pinned: false
  ---

- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+ A Flask web application that serves the Meta Llama-3.1-8B-Instruct model via a REST API.
+
+ ## Features
+
+ - RESTful API with `/compute` endpoint
+ - JSON input/output
+ - Configurable generation parameters
+ - Memory-optimized model loading with 8-bit quantization
+ - CORS support
+ - Error handling and logging
+
+ ## Deployment to Hugging Face Spaces
+
+ This application is configured to run on Hugging Face Spaces using Docker. Once pushed:
+
+ 1. The model will automatically load on startup
+ 2. The `/compute` endpoint will be available at your Space URL
+ 3. Use POST requests with JSON payloads to generate responses
+
+ ## Local Development
+
+ 1. Install the required dependencies:
+ ```bash
+ pip install -r requirements.txt
+ ```
+
+ 2. Run the Flask application:
+ ```bash
+ python app.py
+ ```
+
+ The application will start on `http://localhost:5001` by default.
+
+ ## Usage
+
+ ### Health Check
+
+ ```bash
+ GET http://localhost:5001/
+ ```
+
+ Response:
+ ```json
+ {
+     "status": "success",
+     "message": "Llama-3.1-8B-Instruct Flask API is running",
+     "model_loaded": true
+ }
+ ```
+
+ ### Generate Response
+
+ ```bash
+ POST https://your-space-name-username.hf.space/compute
+ ```
+
+ Request body:
+ ```json
+ {
+     "prompt": "What is the capital of France?",
+     "max_length": 256,
+     "temperature": 0.7,
+     "top_p": 0.9
+ }
+ ```
+
+ Response:
+ ```json
+ {
+     "status": "success",
+     "prompt": "What is the capital of France?",
+     "response": "The capital of France is Paris...",
+     "parameters": {
+         "max_length": 256,
+         "temperature": 0.7,
+         "top_p": 0.9
+     }
+ }
+ ```
+
+ ### Parameters
+
+ - `prompt` (required): The input text prompt
+ - `max_length` (optional): Maximum number of tokens to generate (default: 512)
+ - `temperature` (optional): Sampling temperature (default: 0.7)
+ - `top_p` (optional): Top-p sampling parameter (default: 0.9)
+
+ ## Testing
+
+ Run the test script to verify the API is working:
+
+ ```bash
+ python test_api.py
+ ```
+
+ ## Example with curl
+
+ ```bash
+ # Health check
+ curl http://localhost:5001/
+
+ # Generate response
+ curl -X POST http://localhost:5001/compute \
+     -H "Content-Type: application/json" \
+     -d '{"prompt": "Explain machine learning in simple terms"}'
+ ```
+
+ ## System Requirements
+
+ - Python 3.8+
+ - CUDA-capable GPU (recommended)
+ - At least 16GB RAM
+ - 20GB+ free disk space for model weights
+
+ ## Notes
+
+ - The model uses 8-bit quantization to reduce memory usage
+ - The first request may take longer while the model initializes
+ - The application logs model loading progress and errors
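The request contract documented above can also be exercised from Python rather than curl. The sketch below builds a `/compute` payload with the documented defaults; `build_payload` is a hypothetical helper for illustration, not part of this commit, and the resulting dict would be sent with `requests.post(url, json=payload)` against a running server:

```python
import json

def build_payload(prompt, max_length=512, temperature=0.7, top_p=0.9):
    """Assemble a /compute request body using the defaults from the README."""
    return {
        "prompt": prompt,
        "max_length": max_length,
        "temperature": temperature,
        "top_p": top_p,
    }

# The dict can be POSTed with requests.post(url, json=payload)
payload = build_payload("What is the capital of France?", max_length=256)
print(json.dumps(payload, indent=2))
```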
app.py ADDED
@@ -0,0 +1,209 @@
+ from flask import Flask, request, jsonify
+ from flask_cors import CORS
+ from transformers import AutoTokenizer, AutoModelForCausalLM
+ import torch
+ import logging
+ import os
+
+ # Set up logging
+ logging.basicConfig(level=logging.INFO)
+ logger = logging.getLogger(__name__)
+
+ app = Flask(__name__)
+ CORS(app)  # Enable CORS for all routes
+
+ # Global variables for model and tokenizer
+ model = None
+ tokenizer = None
+
+ def load_model():
+     """Load the Llama model and tokenizer"""
+     global model, tokenizer
+
+     try:
+         logger.info("Loading Llama-3.1-8B-Instruct model...")
+         model_name = "meta-llama/Llama-3.1-8B-Instruct"
+
+         # Load tokenizer
+         tokenizer = AutoTokenizer.from_pretrained(model_name)
+
+         # Set pad token if it does not exist
+         if tokenizer.pad_token is None:
+             tokenizer.pad_token = tokenizer.eos_token
+
+         # Load model with optimizations
+         model = AutoModelForCausalLM.from_pretrained(
+             model_name,
+             torch_dtype=torch.float16,
+             device_map="auto",
+             load_in_8bit=True,  # Use 8-bit quantization to reduce memory usage
+             trust_remote_code=True
+         )
+
+         logger.info("Model loaded successfully!")
+
+     except Exception as e:
+         logger.error(f"Error loading model: {str(e)}")
+         raise
+
+ def generate_response(prompt, max_length=512, temperature=0.7, top_p=0.9):
+     """Generate a response using the loaded Llama model"""
+     global model, tokenizer
+
+     if model is None or tokenizer is None:
+         raise ValueError("Model not loaded. Please ensure the model is properly initialized.")
+
+     try:
+         # Format the prompt for Llama-3.1-Instruct
+         formatted_prompt = f"<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\n{prompt}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
+
+         # Tokenize the input
+         inputs = tokenizer.encode(formatted_prompt, return_tensors="pt")
+
+         # Move to the same device as the model
+         inputs = inputs.to(model.device)
+
+         # Generate response
+         with torch.no_grad():
+             outputs = model.generate(
+                 inputs,
+                 max_length=inputs.shape[1] + max_length,
+                 temperature=temperature,
+                 top_p=top_p,
+                 do_sample=True,
+                 pad_token_id=tokenizer.eos_token_id,
+                 eos_token_id=tokenizer.eos_token_id,
+                 repetition_penalty=1.1
+             )
+
+         # Decode only the newly generated tokens. Slicing off the prompt is
+         # necessary because skip_special_tokens=True removes the header
+         # tokens, so splitting the decoded string on them would never match.
+         response = tokenizer.decode(outputs[0][inputs.shape[1]:], skip_special_tokens=True).strip()
+
+         return response
+
+     except Exception as e:
+         logger.error(f"Error generating response: {str(e)}")
+         raise
+
+ @app.route('/', methods=['GET'])
+ def home():
+     """Health check endpoint"""
+     return jsonify({
+         "status": "success",
+         "message": "Llama-3.1-8B-Instruct Flask API is running",
+         "model_loaded": model is not None and tokenizer is not None
+     })
+
+ @app.route('/compute', methods=['POST'])
+ def compute():
+     """Main endpoint to process prompts and return model responses"""
+     try:
+         # Check if the model is loaded
+         if model is None or tokenizer is None:
+             return jsonify({
+                 "status": "error",
+                 "message": "Model not loaded. Please wait for initialization."
+             }), 503
+
+         # Get JSON data from the request
+         data = request.get_json()
+
+         if not data:
+             return jsonify({
+                 "status": "error",
+                 "message": "No JSON data provided"
+             }), 400
+
+         # Extract the prompt from the JSON payload
+         prompt = data.get('prompt')
+
+         if not prompt:
+             return jsonify({
+                 "status": "error",
+                 "message": "No 'prompt' field found in JSON data"
+             }), 400
+
+         if not isinstance(prompt, str) or len(prompt.strip()) == 0:
+             return jsonify({
+                 "status": "error",
+                 "message": "Prompt must be a non-empty string"
+             }), 400
+
+         # Get optional parameters
+         max_length = data.get('max_length', 512)
+         temperature = data.get('temperature', 0.7)
+         top_p = data.get('top_p', 0.9)
+
+         # Validate parameters, falling back to defaults for out-of-range values
+         if not isinstance(max_length, int) or max_length <= 0 or max_length > 2048:
+             max_length = 512
+
+         if not isinstance(temperature, (int, float)) or temperature <= 0 or temperature > 2:
+             temperature = 0.7
+
+         if not isinstance(top_p, (int, float)) or top_p <= 0 or top_p > 1:
+             top_p = 0.9
+
+         # Generate the response
+         logger.info(f"Processing prompt: {prompt[:100]}...")
+         response = generate_response(prompt, max_length, temperature, top_p)
+
+         return jsonify({
+             "status": "success",
+             "prompt": prompt,
+             "response": response,
+             "parameters": {
+                 "max_length": max_length,
+                 "temperature": temperature,
+                 "top_p": top_p
+             }
+         })
+
+     except Exception as e:
+         logger.error(f"Error in compute endpoint: {str(e)}")
+         return jsonify({
+             "status": "error",
+             "message": f"Internal server error: {str(e)}"
+         }), 500
+
+ @app.errorhandler(404)
+ def not_found(error):
+     return jsonify({
+         "status": "error",
+         "message": "Endpoint not found"
+     }), 404
+
+ @app.errorhandler(500)
+ def internal_error(error):
+     return jsonify({
+         "status": "error",
+         "message": "Internal server error"
+     }), 500
+
+ if __name__ == '__main__':
+     # Load the model when starting the app
+     logger.info("Starting Flask application...")
+
+     try:
+         load_model()
+         logger.info("Application ready!")
+         logger.info("API endpoints:")
+         logger.info("  GET  /        - Health check")
+         logger.info("  POST /compute - Generate responses")
+
+         # Run the Flask app
+         port = int(os.environ.get('PORT', 5001))
+         app.run(
+             host='0.0.0.0',
+             port=port,
+             debug=False,
+             threaded=True
+         )
+
+     except Exception as e:
+         logger.error(f"Failed to start application: {str(e)}")
+         exit(1)
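The header-token prompt layout built inline in `generate_response` above can be factored into a small pure function. The sketch below (`format_llama31_prompt` is a hypothetical helper, not part of this commit) mirrors that string layout; because the prompt ends with the assistant header, generation continues directly with the model's reply:

```python
def format_llama31_prompt(user_prompt: str) -> str:
    """Wrap a user message in Llama-3.1-Instruct header tokens,
    mirroring the formatted_prompt string in generate_response()."""
    return (
        "<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\n"
        f"{user_prompt}<|eot_id|>"
        "<|start_header_id|>assistant<|end_header_id|>\n\n"
    )

formatted = format_llama31_prompt("What is the capital of France?")
print(formatted)
```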
requirements.txt ADDED
@@ -0,0 +1,7 @@
+ flask==3.0.0
+ transformers==4.36.0
+ torch==2.1.0
+ accelerate==0.25.0
+ bitsandbytes==0.41.3
+ flask-cors==4.0.0
+ huggingface_hub
test_api.py ADDED
@@ -0,0 +1,60 @@
+ import requests
+ import json
+
+ # Test the Flask API
+ def test_api():
+     url = "http://localhost:5001/compute"
+
+     # Test data
+     test_data = {
+         "prompt": "What is the capital of France?",
+         "max_length": 256,
+         "temperature": 0.7,
+         "top_p": 0.9
+     }
+
+     try:
+         print("Testing the /compute endpoint...")
+         print(f"Sending prompt: {test_data['prompt']}")
+
+         # Generation can be slow, so allow a generous timeout
+         response = requests.post(url, json=test_data, timeout=300)
+
+         if response.status_code == 200:
+             result = response.json()
+             print("\nResponse received successfully!")
+             print(f"Status: {result['status']}")
+             print(f"Response: {result['response']}")
+         else:
+             print(f"Error: {response.status_code}")
+             print(response.text)
+
+     except requests.exceptions.ConnectionError:
+         print("Error: Could not connect to the server. Make sure the Flask app is running on port 5001.")
+     except Exception as e:
+         print(f"Error: {str(e)}")
+
+ def test_health_check():
+     url = "http://localhost:5001/"
+
+     try:
+         print("Testing health check endpoint...")
+         response = requests.get(url, timeout=10)
+
+         if response.status_code == 200:
+             result = response.json()
+             print("Health check successful!")
+             print(json.dumps(result, indent=2))
+         else:
+             print(f"Error: {response.status_code}")
+             print(response.text)
+
+     except requests.exceptions.ConnectionError:
+         print("Error: Could not connect to the server. Make sure the Flask app is running on port 5001.")
+     except Exception as e:
+         print(f"Error: {str(e)}")
+
+ if __name__ == "__main__":
+     print("=== Flask API Test ===")
+     test_health_check()
+     print("\n" + "="*50 + "\n")
+     test_api()
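The `/compute` endpoint in `app.py` silently clamps out-of-range generation parameters back to their defaults rather than rejecting the request, which the test script above does not exercise. A pure-function sketch of that validation (a hypothetical refactor, not part of this commit) makes the fallback rules easy to check in isolation:

```python
def clamp_parameters(max_length=512, temperature=0.7, top_p=0.9):
    """Apply the same fallback rules as the /compute endpoint:
    wrongly typed or out-of-range values revert to the defaults."""
    if not isinstance(max_length, int) or max_length <= 0 or max_length > 2048:
        max_length = 512
    if not isinstance(temperature, (int, float)) or temperature <= 0 or temperature > 2:
        temperature = 0.7
    if not isinstance(top_p, (int, float)) or top_p <= 0 or top_p > 1:
        top_p = 0.9
    return max_length, temperature, top_p

# In-range values pass through; out-of-range values fall back to defaults
print(clamp_parameters(max_length=256, temperature=1.5, top_p=0.5))    # (256, 1.5, 0.5)
print(clamp_parameters(max_length=999999, temperature=-1, top_p=2.0))  # (512, 0.7, 0.9)
```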