# recruai / app.py
# Author: SyedSyab
# Fix Gradio API to accept JSON directly for backend compatibility (commit 21f4259)
import gradio as gr
from transformers import AutoTokenizer, AutoModel
import torch
import json
# Load your model once
# Module-level so every request reuses the same weights instead of
# re-downloading/re-initializing per call. `tokenizer` and `model` are
# read by get_embedding() below.
model_name = "sentence-transformers/all-MiniLM-L6-v2"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModel.from_pretrained(model_name)
def get_embedding(text):
    """Return a normalized embedding vector for a single text.

    The text is tokenized (truncated to 512 tokens), run through the
    module-level model under no_grad, mean-pooled over the token axis,
    L2-normalized, and returned as a flat Python list of floats.
    """
    encoded = tokenizer(
        text,
        return_tensors='pt',
        truncation=True,
        padding=True,
        max_length=512,
    )
    with torch.no_grad():
        model_output = model(**encoded)
    # Mean pooling over token embeddings gives one vector per input.
    pooled = model_output.last_hidden_state.mean(dim=1)
    # Unit-length (L2) normalization so cosine similarity == dot product.
    unit = torch.nn.functional.normalize(pooled, p=2, dim=1)
    return unit.squeeze().tolist()
def predict_texts(texts):
    """Generate embeddings for a list of texts (for API compatibility).

    Accepts either a single string (promoted to a one-element list) or a
    list of strings. Returns a list of embedding vectors, or an error
    string when the input is malformed.
    """
    # Promote a bare string to a one-element list for convenience.
    if isinstance(texts, str):
        texts = [texts]
    if not isinstance(texts, list):
        return "Error: Input must be a list of texts or a single text string"
    results = []
    for item in texts:
        if not isinstance(item, str):
            return f"Error: All items must be strings, got {type(item)}"
        results.append(get_embedding(item))
    return results
def predict_single_text(text):
    """Generate an embedding for a single text (Gradio web-UI handler).

    Returns a human-readable summary string: the first ten values of the
    embedding plus its total dimensionality. Empty or whitespace-only
    input yields a prompt message instead.
    """
    if not text or not text.strip():
        return "Please enter some text to generate embeddings."
    vector = get_embedding(text.strip())
    return (
        f"Embedding (first 10 values): {vector[:10]}...\n"
        f"Full embedding has {len(vector)} dimensions."
    )
def predict_api(texts):
    """Handle API calls from the backend; expects a list of strings.

    Returns ``{'data': [embedding, ...]}`` on success, or
    ``{'error': message}`` for malformed input or unexpected failures.
    """
    try:
        if not isinstance(texts, list):
            return {'error': 'Input must be a list of texts'}
        results = []
        for item in texts:
            if not isinstance(item, str):
                return {'error': 'All items must be strings'}
            results.append(get_embedding(item))
        return {'data': results}
    except Exception as exc:
        # Surface any unexpected failure as JSON rather than a 500.
        return {'error': str(exc)}
# Create API interface (this will create /api/predict endpoint)
# JSON in / JSON out so the backend can POST a list of texts directly
# without form encoding.
api_interface = gr.Interface(
    fn=predict_api,
    inputs=gr.JSON(),  # Expects JSON input directly
    outputs=gr.JSON(),  # Returns JSON output directly
    # NOTE(review): `api_name` as a gr.Interface constructor kwarg varies
    # across Gradio versions — confirm it is honored on the pinned version.
    api_name="predict"
)
# Create web interface
# Human-facing textbox UI; delegates to predict_single_text, which
# returns a readable summary rather than the raw vector.
web_interface = gr.Interface(
    fn=predict_single_text,
    inputs=gr.Textbox(lines=3, placeholder="Enter text to generate embeddings..."),
    outputs=gr.Textbox(label="Embedding Result"),
    title="Text Embedding Generator",
    description="Generate embeddings for text using sentence-transformers/all-MiniLM-L6-v2 model",
    examples=[
        ["Hello world"],
        ["This is a test sentence for embedding generation."],
        ["Machine learning is transforming the world."]
    ]
)
# Launch both interfaces as tabs of a single app.
# Fix: the original file had a second `if __name__ == '__main__'` guard
# that launched an undefined `iface`, raising NameError once the first
# (blocking) launch returned. Only one guarded launch is kept.
if __name__ == '__main__':
    gr.TabbedInterface(
        [web_interface, api_interface],
        ["Web UI", "API"],
    ).launch(
        server_name="0.0.0.0",  # listen on all interfaces (needed in containers/Spaces)
        server_port=7860,
        share=True,
    )