Euryeth committed on
Commit
aaea5e8
·
verified ·
1 Parent(s): 7cc8d56

Update api.py

Browse files
Files changed (1) hide show
  1. api.py +46 -45
api.py CHANGED
@@ -1,21 +1,19 @@
1
- from flask import Flask, request, jsonify
2
  from transformers import pipeline, AutoTokenizer, AutoModelForCausalLM
3
  import torch
4
  import os
5
  from huggingface_hub import login
 
6
 
7
- # Authenticate with your Hugging Face token
8
  login(os.getenv("HUGGINGFACEHUB_API_TOKEN"))
9
 
10
- # Initialize Flask
11
- app = Flask(__name__)
12
 
13
- # Model settings
14
  model_name = "cerebras/btlm-3b-8k-chat"
15
  torch_dtype = torch.bfloat16 if torch.cuda.is_available() else torch.float32
16
  os.environ['HF_HOME'] = '/tmp/cache'
17
 
18
- # Load tokenizer and model
19
  tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
20
  model = AutoModelForCausalLM.from_pretrained(
21
  model_name,
@@ -24,60 +22,63 @@ model = AutoModelForCausalLM.from_pretrained(
24
  trust_remote_code=True
25
  )
26
 
27
- # Build text generation pipeline
28
  generator = pipeline(
29
  "text-generation",
30
  model=model,
31
  tokenizer=tokenizer,
32
  device_map="auto",
33
  torch_dtype=torch_dtype,
34
- pad_token_id=tokenizer.eos_token_id
 
35
  )
36
 
37
- # Route: /v1/chat/completions (OpenAI-compatible)
 
 
 
 
 
38
  @app.route("/v1/chat/completions", methods=["POST"])
39
  def chat_completions():
40
- try:
41
- data = request.get_json()
42
- messages = data.get("messages", [])
43
- max_tokens = data.get("max_tokens", 256)
44
- temperature = data.get("temperature", 0.7)
45
 
46
- # Extract user message from chat history
47
- prompt = ""
48
- for msg in messages:
49
- role = msg.get("role", "")
50
- content = msg.get("content", "")
51
- if role and content:
52
- prompt += f"{role.capitalize()}: {content}\n"
53
- prompt += "Assistant:"
54
 
55
- # Run generation
56
- output = generator(
57
- prompt,
58
- max_new_tokens=max_tokens,
59
- temperature=temperature,
60
- top_p=0.9,
61
- repetition_penalty=1.1,
62
- do_sample=True
63
- )
64
 
65
- # Extract and format response
66
- generated_text = output[0]['generated_text']
67
- assistant_reply = generated_text.replace(prompt, "").strip()
 
 
 
 
 
68
 
69
- return jsonify({
70
- "choices": [{
71
- "message": {
72
- "role": "assistant",
73
- "content": assistant_reply
74
- }
75
- }]
76
- })
77
 
78
- except Exception as e:
79
- return jsonify({"error": str(e)}), 500
 
 
 
 
 
 
 
 
80
 
81
- # Run Flask app
82
  if __name__ == "__main__":
83
  app.run(host="0.0.0.0", port=8080)
 
 
1
  from transformers import pipeline, AutoTokenizer, AutoModelForCausalLM
2
  import torch
3
  import os
4
  from huggingface_hub import login
5
+ from flask import Flask, request, jsonify
6
 
7
# Authenticate with the Hugging Face Hub so gated/private models can be pulled.
# Calling login(None) when the secret is missing fails (or tries to prompt
# interactively), which would break a headless Space — so only log in when the
# token is actually configured.
_hub_token = os.getenv("HUGGINGFACEHUB_API_TOKEN")
if _hub_token:
    login(_hub_token)

# Bearer token that clients must present on /v1/chat/completions.
# Set HF_API_TOKEN in your Space Secrets; if unset this is None and every
# request is rejected with 401.
API_TOKEN = os.getenv("HF_API_TOKEN")
 
11
 
12
# Model configuration.
model_name = "cerebras/btlm-3b-8k-chat"

# Prefer bfloat16 when a CUDA device is present; fall back to float32 on CPU.
if torch.cuda.is_available():
    torch_dtype = torch.bfloat16
else:
    torch_dtype = torch.float32

# Cache Hub downloads under /tmp (a writable path in containerized Spaces).
os.environ['HF_HOME'] = '/tmp/cache'
16
 
 
17
# Load the chat model and its tokenizer from the Hub.
# trust_remote_code=True executes the model repo's custom modeling code
# (BTLM ships its own architecture) — acceptable only because model_name is
# pinned to a specific, trusted repo above.
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
# NOTE(review): the diff rendering elides two argument lines (original file
# lines 20-21) from this call — presumably torch_dtype/device settings.
# Confirm against the full file before changing this call.
model = AutoModelForCausalLM.from_pretrained(
model_name,
trust_remote_code=True
)
24
 
 
25
# Text-generation pipeline shared by all request handlers.
# pad_token_id is pinned to EOS — presumably because the tokenizer defines no
# dedicated pad token (confirm against the tokenizer config).
_pipeline_options = {
    "model": model,
    "tokenizer": tokenizer,
    "device_map": "auto",
    "torch_dtype": torch_dtype,
    "pad_token_id": tokenizer.eos_token_id,
    "trust_remote_code": True,
}
generator = pipeline("text-generation", **_pipeline_options)
34
 
35
# WSGI application object; routes are registered below via decorators.
app = Flask(__name__)


@app.route("/")
def home():
    """Health-check endpoint: confirms the Space's web server is up."""
    return "API is running"
40
+
41
@app.route("/v1/chat/completions", methods=["POST"])
def chat_completions():
    """OpenAI-compatible chat completion endpoint.

    Expects a JSON body with "messages" (list of {"role": ..., "content": ...}
    dicts) plus optional "max_tokens" (default 256) and "temperature"
    (default 0.7). Requires an "Authorization: Bearer <HF_API_TOKEN>" header.
    Returns a single choice with the assistant's reply, or a JSON error with
    status 401/500.
    """
    # Bearer-token check: reject before touching the request body.
    auth_header = request.headers.get("Authorization", "")
    if not auth_header.startswith("Bearer "):
        return jsonify({"error": "Unauthorized"}), 401
    supplied = auth_header[len("Bearer "):]
    # If HF_API_TOKEN is unset, API_TOKEN is None and every request is refused.
    if not API_TOKEN or supplied != API_TOKEN:
        return jsonify({"error": "Unauthorized"}), 401

    try:
        # silent=True returns None on malformed JSON instead of aborting with
        # an HTML 400 page; fall back to an empty payload.
        data = request.get_json(silent=True) or {}
        messages = data.get("messages", [])
        max_tokens = data.get("max_tokens", 256)
        temperature = data.get("temperature", 0.7)

        # Flatten the chat history into a single "Role: content" prompt.
        prompt = ""
        for msg in messages:
            role = msg.get("role", "")
            content = msg.get("content", "")
            if role and content:
                prompt += f"{role.capitalize()}: {content}\n"
        prompt += "Assistant:"

        output = generator(
            prompt,
            max_new_tokens=max_tokens,
            temperature=temperature,
            top_p=0.9,
            repetition_penalty=1.1,
            do_sample=True,
        )

        generated_text = output[0]['generated_text']
        # The pipeline echoes the prompt before the completion. Strip only the
        # leading echo — str.replace(prompt, "") would also delete any repeat
        # of the prompt text inside the generated reply itself.
        if generated_text.startswith(prompt):
            assistant_reply = generated_text[len(prompt):].strip()
        else:
            assistant_reply = generated_text.strip()

        return jsonify({
            "choices": [{
                "message": {
                    "role": "assistant",
                    "content": assistant_reply
                },
                "finish_reason": "stop",
                "index": 0
            }]
        })
    except Exception as e:
        # Surface failures as a JSON error body (the previous revision had
        # this guard; the rewrite dropped it, yielding HTML 500 pages).
        return jsonify({"error": str(e)}), 500
82
 
 
83
def _serve() -> None:
    """Run the Flask development server on every interface, port 8080."""
    app.run(host="0.0.0.0", port=8080)


if __name__ == "__main__":
    _serve()