Spaces:
Sleeping
Sleeping
| """Gradio interface for Hugging Face inference.""" | |
| import gradio as gr | |
| from huggingface_hub import InferenceClient | |
| try: | |
| import spaces | |
| SPACES_AVAILABLE = True | |
| except ImportError: | |
| SPACES_AVAILABLE = False | |
| from app.config import get_settings | |
# Load the app configuration once at import time.
settings = get_settings()
# A single module-level client, bound to the configured model and token,
# is reused by every request handled below.
client = InferenceClient(
    model=settings.model_name,
    token=settings.api_token,
)
| def _predict(text: str) -> str: | |
| """Run inference on the input text.""" | |
| if not text.strip(): | |
| return "Please enter some text." | |
| task = settings.task | |
| try: | |
| if task in ("text-classification", "sentiment-analysis"): | |
| results = client.text_classification(text) | |
| output = "\n".join( | |
| [f"{r['label']}: {r['score']:.2%}" for r in results] | |
| ) | |
| elif task == "text-generation": | |
| output = client.text_generation(text, max_new_tokens=100) | |
| elif task == "summarization": | |
| output = client.summarization(text) | |
| elif task == "translation": | |
| output = client.translation(text) | |
| elif task == "fill-mask": | |
| results = client.fill_mask(text) | |
| output = "\n".join( | |
| [f"{r['token_str']}: {r['score']:.2%}" for r in results] | |
| ) | |
| else: | |
| output = str(client.post(json={"inputs": text})) | |
| return output | |
| except Exception as e: | |
| return f"Error: {e}" | |
# Wrap the handler with spaces.GPU(duration=60) only when the optional
# `spaces` package imported successfully (SPACES_AVAILABLE is True);
# otherwise the undecorated function is exposed as-is.
# NOTE(review): _predict only calls the remote Inference API, so the GPU
# request may be unnecessary. Confirm whether it is kept just to satisfy
# ZeroGPU startup requirements.
predict = spaces.GPU(duration=60)(_predict) if SPACES_AVAILABLE else _predict
# Single text-in / text-out interface around `predict`; the header shows the
# configured model and task.
demo = gr.Interface(
    fn=predict,
    inputs=gr.Textbox(label="Input Text", placeholder="Enter text here...", lines=4),
    outputs=gr.Textbox(label="Result", lines=6),
    title="Hugging Face Inference",
    description=f"Model: **{settings.model_name}** | Task: **{settings.task}**",
    # One row per example, one value per input component (a single Textbox).
    # NOTE(review): these samples suit sentiment/classification; other
    # configured tasks (e.g. fill-mask) may need different examples.
    examples=[
        [sample]
        for sample in (
            "I love this product! It's amazing.",
            "This is the worst experience ever.",
            "The weather is nice today.",
        )
    ],
    flagging_mode="never",
)
if __name__ == "__main__":
    # Listen on all network interfaces, on port 7860.
    demo.launch(
        server_name="0.0.0.0",
        server_port=7860,
    )