| from fastapi import FastAPI |
| from pydantic import BaseModel |
| from transformers import AutoTokenizer, AutoModelForCausalLM |
| import torch |
|
|
# Checkpoint identifier handed to both from_pretrained() calls below.
MODEL_ID = "meta-llama/Meta-Llama-3.1-8B-Instruct"

# ASGI application object; the route in this module is registered on it.
app = FastAPI(title="Artist Description Generator")

# Tokenizer and weights are loaded once, at import time, and shared by
# every request handled by this process.
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)

# Probe CUDA a single time: with a GPU the weights are requested in
# float16 with 4-bit loading enabled, otherwise plain float32 is used.
# NOTE(review): newer transformers releases favour passing a
# quantization_config instead of the bare load_in_4bit flag -- confirm
# against the pinned transformers version before changing this.
_cuda_available = torch.cuda.is_available()
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    device_map="auto",
    torch_dtype=torch.float16 if _cuda_available else torch.float32,
    load_in_4bit=_cuda_available,
)
|
|
class ArtistInput(BaseModel):
    """Request body for ``/generate``: the known facts about one artist.

    Only ``artist_name`` is required; every other field has a default and
    may be left out of the payload.
    """

    # Name of the artist; echoed back in the response.
    artist_name: str

    # Optional country; None when not supplied.
    country: str | None = None

    # Genre labels and track titles; empty lists when not supplied.
    # NOTE(review): the mutable literals used as defaults here rely on
    # pydantic giving each instance its own copy -- confirm for the
    # pinned pydantic version.
    genres: list[str] = []
    top_tracks: list[str] = []

    # Free-form metrics mapping, interpolated into the prompt as-is.
    metrics: dict = {}
|
|
def build_prompt(data: ArtistInput) -> str:
    """Build the instruction prompt for one artist.

    The prompt consists of a fixed rules header, one ``Label: value`` line
    per fact, and a trailing ``Artist description:`` cue for the model to
    continue from.

    Only fields that actually carry data are rendered. Rendering a missing
    country as ``Country: None`` or an empty genre list as a blank
    ``Genres:`` line contradicts the "If data is missing, omit it" rule and
    gives the model placeholder text it may echo back. For a fully
    populated payload the prompt is the same as before.

    Args:
        data: Artist facts supplied by the caller.

    Returns:
        The prompt text, without leading or trailing whitespace.
    """
    header = "\n".join(
        [
            "You are writing a factual artist description for a music analytics platform.",
            "",
            "Rules:",
            "- Use ONLY the provided data",
            "- Do NOT invent awards, numbers, or events",
            "- If data is missing, omit it",
            "- Keep it concise and neutral",
            "- 3–5 sentences maximum",
        ]
    )

    # The artist name is the one required field, so it is always present.
    facts = [f"Artist name: {data.artist_name}"]
    if data.country:
        facts.append(f"Country: {data.country}")
    if data.genres:
        facts.append(f"Genres: {', '.join(data.genres)}")
    if data.top_tracks:
        facts.append(f"Top tracks: {', '.join(data.top_tracks)}")
    if data.metrics:
        # Rendered with the mapping's own repr, as the original template did.
        facts.append(f"Metrics: {data.metrics}")

    return "\n\n".join([header, "\n".join(facts), "Artist description:"])
|
|
@app.post("/generate")
def generate_description(data: ArtistInput):
    """Generate a short description for the artist in the request body.

    Builds the prompt from ``data``, samples up to 180 new tokens from the
    model, and returns the generated text together with the artist name.

    Args:
        data: Artist facts posted by the client.

    Returns:
        A dict with ``"artist"`` (the submitted name) and ``"description"``
        (the generated text, stripped of surrounding whitespace).
    """
    prompt = build_prompt(data)
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

    # Length of the prompt in tokens, used below to separate the model's
    # continuation from the echoed prompt.
    prompt_length = inputs["input_ids"].shape[-1]

    with torch.no_grad():
        output = model.generate(
            **inputs,
            max_new_tokens=180,
            temperature=0.3,
            top_p=0.9,
            do_sample=True,
        )

    # The returned sequence starts with the prompt tokens. Drop them by
    # position rather than by searching the decoded text for the
    # "Artist description:" marker: a marker search picks the wrong span
    # whenever user-supplied fields or the generated text contain that
    # phrase.
    generated_ids = output[0][prompt_length:]
    description = tokenizer.decode(generated_ids, skip_special_tokens=True).strip()

    return {
        "artist": data.artist_name,
        "description": description,
    }
|
|