Jamiiwej2903 commited on
Commit
b71794b
·
verified ·
1 Parent(s): 0cb474d

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +21 -10
main.py CHANGED
@@ -1,13 +1,14 @@
1
  from fastapi import FastAPI
2
  from pydantic import BaseModel
3
- from huggingface_hub import InferenceClient
4
  import base64
5
  import uvicorn
 
6
 
7
  app = FastAPI()
8
 
9
- # Initialize the InferenceClient
10
- client = InferenceClient("facebook/musicgen-small")
11
 
12
  class Item(BaseModel):
13
  prompt: str
@@ -16,15 +17,25 @@ class Item(BaseModel):
16
  @app.post("/generate/")
17
  async def generate_music(item: Item):
18
  try:
19
- # Call the Hugging Face Inference API
20
- audio = client.audio_generation(
21
- item.prompt,
22
- max_new_tokens=256,
23
- duration=item.duration
24
- )
 
 
 
 
 
 
 
 
 
 
25
 
26
  # Convert audio to base64
27
- audio_base64 = base64.b64encode(audio).decode('utf-8')
28
 
29
  return {"audio": audio_base64}
30
  except Exception as e:
 
1
  from fastapi import FastAPI
2
  from pydantic import BaseModel
3
+ import requests
4
  import base64
5
  import uvicorn
6
+ import os
7
 
8
  app = FastAPI()
9
 
10
+ # Get the Hugging Face API token from environment variable
11
+ API_TOKEN = os.environ.get("HF_API_TOKEN")
12
 
13
  class Item(BaseModel):
14
  prompt: str
 
17
  @app.post("/generate/")
18
  async def generate_music(item: Item):
19
  try:
20
+ # Prepare the API request
21
+ API_URL = "https://api-inference.huggingface.co/models/facebook/musicgen-small"
22
+ headers = {"Authorization": f"Bearer {API_TOKEN}"}
23
+ payload = {
24
+ "inputs": item.prompt,
25
+ "parameters": {"duration": item.duration}
26
+ }
27
+
28
+ # Make the API call
29
+ response = requests.post(API_URL, headers=headers, json=payload)
30
+
31
+ if response.status_code != 200:
32
+ return {"error": f"API request failed with status code {response.status_code}: {response.text}"}
33
+
34
+ # The response content is the audio file
35
+ audio_content = response.content
36
 
37
  # Convert audio to base64
38
+ audio_base64 = base64.b64encode(audio_content).decode('utf-8')
39
 
40
  return {"audio": audio_base64}
41
  except Exception as e: