Spaces:
Runtime error
Runtime error
| import time | |
| import base64 | |
| import gradio as gr | |
| from sentence_transformers import SentenceTransformer | |
| import httpx | |
| import json | |
| from utils import get_tags_for_prompts, get_mubert_tags_embeddings, get_pat | |
# Sentence-embedding model shared by every request; loaded once at import time.
_MINILM_MODEL_ID = 'all-MiniLM-L6-v2'
minilm = SentenceTransformer(_MINILM_MODEL_ID)
# Embeddings of the Mubert tags (computed by utils.get_mubert_tags_embeddings),
# later passed to get_tags_for_prompts to map a prompt onto tags.
mubert_tags_embeddings = get_mubert_tags_embeddings(minilm)
def get_track_by_tags(tags, pat, duration, maxit=20, loop=False, poll_interval=1.0):
    """Request a Mubert text-to-music recording and wait until it is downloadable.

    Args:
        tags: Tags sent as ``params.tags`` in the RecordTrackTTM request.
        pat: Access token sent as ``params.pat``.
        duration: Requested duration, sent as ``params.duration``
            (presumably seconds -- confirm against the Mubert API docs).
        maxit: Maximum number of times the download link is polled.
        loop: If true, request ``mode="loop"``; otherwise ``mode="track"``.
        poll_interval: Seconds to wait between polls of the download link.

    Returns:
        The track's download URL, once a GET on it returns HTTP 200.

    Raises:
        httpx.HTTPStatusError: If the request fails with an HTTP error and a
            non-JSON body.
        RuntimeError: If the API reports a status other than 1.
        TimeoutError: If the track is still unavailable after ``maxit`` polls.
    """
    mode = "loop" if loop else "track"
    response = httpx.post(
        'https://api-b2b.mubert.com/v2/RecordTrackTTM',
        json={
            "method": "RecordTrackTTM",
            "params": {
                "pat": pat,
                "duration": duration,
                "tags": tags,
                "mode": mode,
            },
        },
    )
    try:
        rdata = response.json()
    except ValueError:
        # Non-JSON body (e.g. a gateway error page): report the HTTP status
        # if there is one, otherwise propagate the decode error.
        response.raise_for_status()
        raise
    # Explicit check rather than `assert`, which is stripped under `python -O`.
    if rdata.get('status') != 1:
        error = rdata.get('error') or {}
        raise RuntimeError(error.get('text', rdata) if isinstance(error, dict) else error)
    trackurl = rdata['data']['tasks'][0]['download_link']

    print('Generating track ', end='', flush=True)
    for attempt in range(maxit):
        if httpx.get(trackurl).status_code == 200:
            print()  # terminate the progress line
            return trackurl
        print('.', end='', flush=True)
        # No point sleeping after the final failed attempt.
        if attempt < maxit - 1:
            time.sleep(poll_interval)
    print()
    # Previously this fell through and returned None, which the UI showed as an empty result.
    raise TimeoutError(f'Track not ready after {maxit} attempts: {trackurl}')
def generate_track_by_prompt(prompt, duration=30, loop=False):
    """Generate a Mubert track for a free-text prompt (Gradio callback).

    The prompt is mapped to tags via ``get_tags_for_prompts`` using the
    module-level ``minilm`` model and ``mubert_tags_embeddings``. The tags are
    then passed to ``get_track_by_tags``.

    Args:
        prompt: Text entered in the Gradio input box.
        duration: Track duration forwarded to ``get_track_by_tags``.
        loop: Forwarded to ``get_track_by_tags`` to request a loop instead of a track.

    Returns:
        The track download URL on success. On failure, returns an error message
        string, which is displayed in the "Result" text box.
    """
    try:
        pat = get_pat("mail@mail.com")
        _, tags = get_tags_for_prompts(minilm, mubert_tags_embeddings, [prompt])[0]
        result = get_track_by_tags(tags, pat, int(duration), loop=loop)
    except Exception as e:
        # UI boundary: keep the app responsive and show the message, but log the
        # full traceback so failures are diagnosable from the server logs.
        import traceback
        traceback.print_exc()
        # Some exceptions carry no message; fall back to the exception type.
        return str(e) or type(e).__name__
    if result is None:
        # The tag-to-track helper can give up without producing a URL;
        # don't present that as an empty "success".
        return "Track generation timed out; please try again."
    print(result)
    return result
# Gradio UI: one text box in; the track URL (or an error message) out.
iface = gr.Interface(fn=generate_track_by_prompt, inputs=["text"], outputs=[gr.Text(label="Result")])
try:
    # Gradio 3.x queue API.
    iface.queue(max_size=32, concurrency_count=20)
except (TypeError, DeprecationWarning):
    # Gradio 4+ removed `concurrency_count`: passing it raises instead of queueing.
    # Its replacement for the default per-event concurrency is `default_concurrency_limit`.
    iface.queue(max_size=32, default_concurrency_limit=20)
iface.launch()