ghosthets committed on
Commit
2aa19f3
·
verified ·
1 Parent(s): 05fe403

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +29 -49
app.py CHANGED
@@ -1,30 +1,36 @@
1
  import flask
2
  from flask import request, jsonify
3
- # Use AutoModelForCausalLM for Decoder-only models like Qwen
4
  from transformers import AutoTokenizer, AutoModelForCausalLM
5
  import torch
6
 
7
- # Initialize the Flask application
8
  app = flask.Flask(__name__)
9
 
10
- # Qwen1.5-0.5B-Chat Model ID
11
- model_id = "Qwen/Qwen1.5-0.5B-Chat"
 
 
 
12
 
13
  print(f"🔄 Loading {model_id} model...")
14
 
15
- # Load the tokenizer
16
  tokenizer = AutoTokenizer.from_pretrained(model_id)
17
 
18
- # Load the model using the correct CausalLM class
19
- # Using bfloat16 for better memory/speed if a compatible GPU is available
20
- model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16)
 
 
21
 
22
- # Set the device (GPU/CPU)
23
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
24
  model.to(device)
25
 
26
- print(f"✅ {model_id} Model loaded successfully!")
27
 
 
 
 
28
  @app.route('/chat', methods=['POST'])
29
  def chat():
30
  try:
@@ -34,57 +40,31 @@ def chat():
34
  if not msg:
35
  return jsonify({"error": "No message sent"}), 400
36
 
37
- # --- Qwen1.5 Chat Template Formatting ---
38
- # Qwen models require input in the ChatML format.
39
- chat_history = [{"role": "user", "content": msg}]
40
-
41
- # apply_chat_template handles the specific formatting (e.g., <|im_start|>user\n...)
42
- formatted_prompt = tokenizer.apply_chat_template(
43
- chat_history,
44
- tokenize=False,
45
- add_generation_prompt=True
46
- )
47
-
48
- # Tokenize the formatted prompt
49
- inputs = tokenizer(formatted_prompt, return_tensors="pt").to(device)
50
-
51
- # Generation configuration
52
  output = model.generate(
53
  **inputs,
54
- max_length=256,
55
  do_sample=True,
56
- top_p=0.8,
57
  temperature=0.6,
58
- # Set pad_token_id to eos_token_id, which is often necessary for Causal LMs
59
- pad_token_id=tokenizer.eos_token_id
60
  )
61
 
62
- # Decode the full output
63
- full_reply = tokenizer.decode(output[0], skip_special_tokens=False)
64
-
65
- # --- Extract only the Generated Response ---
66
-
67
- # Qwen ChatML format uses '<|im_start|>assistant\n' before the response
68
- assistant_tag = "<|im_start|>assistant\n"
69
-
70
- if assistant_tag in full_reply:
71
- # Split the full reply and take the content after the assistant tag
72
- reply = full_reply.split(assistant_tag)[-1].strip()
73
-
74
- # Remove the end-of-message tag if it was generated
75
- if "<|im_end|>" in reply:
76
- reply = reply.split("<|im_end|>")[0].strip()
77
- else:
78
- # Fallback: Decode only the newly generated tokens
79
- reply = tokenizer.decode(output[0][inputs['input_ids'].shape[1]:], skip_special_tokens=True).strip()
80
 
81
  return jsonify({"reply": reply})
82
 
83
  except Exception as e:
84
- # Catch any runtime errors
85
  return jsonify({"error": str(e)}), 500
86
 
87
 
88
  if __name__ == "__main__":
89
- # Run the Flask app
90
  app.run(host='0.0.0.0', port=7860)
 
import flask
from flask import request, jsonify
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch

app = flask.Flask(__name__)

# ---------------------------
# SMALL LLM MODEL (1–2 GB)
# ---------------------------
# Small chat checkpoint that fits comfortably in 1–2 GB of memory.
model_id = "HuggingFaceTB/SmolLM-1.7B-Chat"

print(f"🔄 Loading {model_id} model...")

# Tokenizer for the chosen checkpoint.
tokenizer = AutoTokenizer.from_pretrained(model_id)

# Decide CUDA availability once and reuse it for dtype and device selection.
use_cuda = torch.cuda.is_available()

# Half precision on GPU to save memory; float32 on CPU, where fp16 kernels
# are poorly supported.
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.float16 if use_cuda else torch.float32,
)

# Place the model on the best available device.
device = torch.device("cuda" if use_cuda else "cpu")
model.to(device)

print(f"✅ {model_id} loaded successfully!")
# ---------------------------
# Chat Endpoint
# ---------------------------
@app.route('/chat', methods=['POST'])
def chat():
    """Generate one chat reply for the JSON message posted by the client.

    Returns {"reply": <text>} on success, {"error": ...} with HTTP 400 when
    no message was sent, or HTTP 500 on any runtime failure.
    """
    try:
        # NOTE(review): the request-parsing lines were elided in the diff
        # rendering; reconstructed from the `if not msg` guard below.
        # The "message" key is an assumption — confirm against the client.
        data = request.get_json(silent=True) or {}
        msg = data.get("message", "")

        if not msg:
            return jsonify({"error": "No message sent"}), 400

        # SmolLM chat prompt with user/assistant role tags.
        prompt = f"<|user|>\n{msg}\n<|assistant|>\n"

        inputs = tokenizer(prompt, return_tensors="pt").to(device)

        output = model.generate(
            **inputs,
            max_new_tokens=256,
            do_sample=True,
            temperature=0.6,
            top_p=0.8,
            # Causal LMs often define no pad token; fall back to EOS.
            pad_token_id=tokenizer.eos_token_id,
        )

        # BUG FIX: decoding the full sequence with skip_special_tokens=True
        # strips the <|assistant|> marker, so splitting on it never fired and
        # the reply echoed the whole prompt. Decode only the newly generated
        # tokens (everything past the prompt length) instead.
        input_len = inputs["input_ids"].shape[1]
        reply = tokenizer.decode(
            output[0][input_len:], skip_special_tokens=True
        ).strip()

        return jsonify({"reply": reply})

    except Exception as e:
        # Route-boundary catch-all: surface the error to the client as JSON.
        return jsonify({"error": str(e)}), 500
if __name__ == "__main__":
    # Bind to all interfaces on port 7860 (the Hugging Face Spaces default).
    app.run(host='0.0.0.0', port=7860)