arthu1 committed on
Commit
1eb395b
·
verified ·
1 Parent(s): 6e3cfcf

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +55 -0
app.py ADDED
@@ -0,0 +1,55 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from flask import Flask, request, jsonify
2
+ from transformers import AutoModelForCausalLM, AutoTokenizer
3
+ import torch
4
+
5
# Flask application instance; the route below is registered on it.
app = Flask(__name__)

# Load your model (can be any HF model)
MODEL_NAME = "tiiuae/falcon-7b-instruct"
# Tokenizer and model are loaded once at import time and shared by every
# request — loading a 7B model is slow, so it must happen at startup,
# not per request.
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    torch_dtype=torch.bfloat16,  # half-precision weights to cut memory use
    device_map="auto"  # NOTE(review): auto device placement — presumably requires `accelerate`; confirm it is installed
)
15
+
16
@app.route('/setToken', methods=['POST'])
def set_token():
    """
    Main text-generation API endpoint.

    Expects a JSON body with optional keys:
      system_prompt (str): instruction prepended to the conversation
                           (default "You are a helpful AI.").
      user_input (str): the user's message (default "").
      temperature (number): sampling temperature (default 0.7);
                            values <= 0 fall back to greedy decoding.
      mode (str): only "text" is supported for now (default "text").

    Returns:
      200 with JSON {"model", "response", "mode"} on success,
      400 with JSON {"error": ...} on malformed input or unsupported mode.
    """
    # silent=True: malformed JSON yields None instead of raising, so we can
    # answer with a clean 400 rather than an opaque 500.
    data = request.get_json(force=True, silent=True)
    if data is None:
        return jsonify({"error": "Request body must be valid JSON"}), 400

    system_prompt = data.get("system_prompt", "You are a helpful AI.")
    user_input = data.get("user_input", "")
    mode = data.get("mode", "text")

    # Validate temperature explicitly; a non-numeric value would otherwise
    # raise ValueError and surface as a 500.
    try:
        temperature = float(data.get("temperature", 0.7))
    except (TypeError, ValueError):
        return jsonify({"error": "temperature must be a number"}), 400

    # Text mode (default)
    if mode == "text":
        full_prompt = f"{system_prompt}\nUser: {user_input}\nAI:"
        inputs = tokenizer(full_prompt, return_tensors="pt").to(model.device)

        # generate() rejects non-positive temperatures when sampling;
        # fall back to greedy decoding instead of crashing.
        do_sample = temperature > 0
        outputs = model.generate(
            **inputs,
            max_new_tokens=512,
            do_sample=do_sample,
            temperature=temperature if do_sample else None,
        )

        # Decode only the newly generated tokens. Decoding the full sequence
        # and splitting on "AI:" (the original approach) breaks whenever the
        # user's own text contains the substring "AI:".
        prompt_len = inputs["input_ids"].shape[1]
        response = tokenizer.decode(
            outputs[0][prompt_len:], skip_special_tokens=True
        ).strip()

        return jsonify({
            "model": MODEL_NAME,
            "response": response,
            "mode": "text",
        })

    # You can later add multimodal branches here:
    # - "image" -> call image generation pipeline
    # - "audio" -> call speech-to-text / text-to-speech
    # - "embedding" -> return vector embeddings

    return jsonify({"error": f"Unsupported mode: {mode}"}), 400
53
+
54
if __name__ == '__main__':
    # Bind on all interfaces so the API is reachable from outside the
    # container; 7860 is the conventional HF Spaces port.
    serve_host, serve_port = '0.0.0.0', 7860
    app.run(host=serve_host, port=serve_port)