SyedSyab committed on
Commit
b114468
·
1 Parent(s): 7b031dc

Convert to Gradio interface with API compatibility

Browse files
Files changed (2) hide show
  1. app.py +70 -18
  2. requirements.txt +1 -1
app.py CHANGED
@@ -1,9 +1,7 @@
1
- from flask import Flask, request, jsonify
2
  from transformers import AutoTokenizer, AutoModel
3
  import torch
4
- import os
5
-
6
- app = Flask(__name__)
7
 
8
  # Load your model once
9
  model_name = "sentence-transformers/all-MiniLM-L6-v2"
@@ -21,16 +19,46 @@ def get_embedding(text):
21
  embeddings = torch.nn.functional.normalize(embeddings, p=2, dim=1)
22
  return embeddings.squeeze().tolist()
23
 
24
- @app.route('/api/predict', methods=['POST'])
25
- def predict():
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
26
  try:
27
- data = request.get_json()
 
 
28
  if not data or 'data' not in data:
29
- return jsonify({'error': 'Missing data field'}), 400
30
 
31
  texts = data['data']
32
  if not isinstance(texts, list):
33
- return jsonify({'error': 'Data must be a list of texts'}), 400
34
 
35
  # Generate embeddings for each text
36
  embeddings = []
@@ -39,17 +67,41 @@ def predict():
39
  embedding = get_embedding(text)
40
  embeddings.append(embedding)
41
  else:
42
- return jsonify({'error': 'All items in data must be strings'}), 400
43
-
44
- return jsonify({'data': embeddings})
45
 
 
46
  except Exception as e:
47
- return jsonify({'error': str(e)}), 500
48
 
49
- @app.route('/health', methods=['GET'])
50
- def health():
51
- return jsonify({'status': 'healthy'})
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
52
 
53
  if __name__ == '__main__':
54
- port = int(os.environ.get('PORT', 7860))
55
- app.run(host='0.0.0.0', port=port)
 
1
+ import gradio as gr
2
  from transformers import AutoTokenizer, AutoModel
3
  import torch
4
+ import json
 
 
5
 
6
  # Load your model once
7
  model_name = "sentence-transformers/all-MiniLM-L6-v2"
 
19
  embeddings = torch.nn.functional.normalize(embeddings, p=2, dim=1)
20
  return embeddings.squeeze().tolist()
21
 
22
def predict_texts(texts):
    """Embed a list of texts (or a single text) for API compatibility.

    Returns a list of embedding vectors on success, or an error string
    describing the first invalid input encountered.
    """
    # A bare string is treated as a one-element batch.
    if isinstance(texts, str):
        texts = [texts]

    if not isinstance(texts, list):
        return "Error: Input must be a list of texts or a single text string"

    # Embed in order, bailing out on the first non-string item.
    results = []
    for item in texts:
        if not isinstance(item, str):
            return f"Error: All items must be strings, got {type(item)}"
        results.append(get_embedding(item))
    return results
41
+
42
def predict_single_text(text):
    """Embed one piece of text and render a short human-readable summary.

    Used by the Gradio web tab; blank or whitespace-only input yields a
    prompt message instead of an embedding.
    """
    cleaned = text.strip() if text else ""
    if not cleaned:
        return "Please enter some text to generate embeddings."

    vector = get_embedding(cleaned)
    # Show only a preview — the full vector is too long for the UI.
    return f"Embedding (first 10 values): {vector[:10]}...\nFull embedding has {len(vector)} dimensions."
49
+
50
def predict_api(json_str):
    """Handle API calls from the backend.

    Expects a JSON string of the form {"data": ["text1", "text2", ...]}
    and returns a JSON string: {"data": [...embeddings...]} on success,
    or {"error": "..."} on any failure (malformed JSON, missing or
    invalid fields, model errors). Never raises.
    """
    # Fix: dropped the redundant function-local `import json` that shadowed
    # the module-level import and re-ran on every call.
    try:
        data = json.loads(json_str)

        # `not data` also rejects JSON null / {} before the key lookup.
        if not data or 'data' not in data:
            return json.dumps({'error': 'Missing data field'})

        texts = data['data']
        if not isinstance(texts, list):
            return json.dumps({'error': 'Data must be a list of texts'})

        # Generate embeddings for each text, failing fast on non-strings.
        embeddings = []
        for text in texts:
            if isinstance(text, str):
                embeddings.append(get_embedding(text))
            else:
                return json.dumps({'error': 'All items must be strings'})

        return json.dumps({'data': embeddings})
    except Exception as e:
        # json.JSONDecodeError and any model failure surface here; the API
        # contract is "always return a JSON string".
        return json.dumps({'error': str(e)})
75
 
76
# --- Gradio interfaces --------------------------------------------------

# API tab: takes a raw JSON string and returns a JSON string; api_name
# exposes it as the /api/predict endpoint the backend calls.
api_interface = gr.Interface(
    fn=predict_api,
    inputs=gr.Textbox(),
    outputs=gr.Textbox(),
    api_name="predict",
)

# Web tab: human-friendly single-text embedding demo.
web_interface = gr.Interface(
    fn=predict_single_text,
    inputs=gr.Textbox(lines=3, placeholder="Enter text to generate embeddings..."),
    outputs=gr.Textbox(label="Embedding Result"),
    title="Text Embedding Generator",
    description="Generate embeddings for text using sentence-transformers/all-MiniLM-L6-v2 model",
    examples=[
        ["Hello world"],
        ["This is a test sentence for embedding generation."],
        ["Machine learning is transforming the world."],
    ],
)
97
+
98
# Launch both interfaces when run as a script.
if __name__ == '__main__':
    # NOTE(review): the original file had a second `__main__` guard calling
    # iface.launch(), but `iface` was never defined — removed to avoid a
    # NameError after the first launch() returns.
    # share=True is ignored on Hugging Face Spaces but yields a public
    # link when running locally.
    gr.TabbedInterface([web_interface, api_interface], ["Web UI", "API"]).launch(
        server_name="0.0.0.0",
        server_port=7860,
        share=True
    )
 
requirements.txt CHANGED
@@ -1,4 +1,4 @@
1
- Flask==2.3.3
2
  transformers==4.36.0
3
  torch>=2.0.0
4
  numpy>=1.21.0
 
1
+ gradio==4.36.0
2
  transformers==4.36.0
3
  torch>=2.0.0
4
  numpy>=1.21.0