File size: 1,767 Bytes
665309d
 
761983b
c61da1e
665309d
e439552
badc904
754da99
 
9d4f74f
 
 
 
 
e439552
754da99
 
 
 
790a676
badc904
665309d
 
 
754da99
790a676
754da99
 
 
 
 
790a676
 
 
754da99
790a676
67bfdbf
754da99
 
 
665309d
37b6d19
2240a2e
790a676
d969ee4
 
37b6d19
 
 
 
 
 
 
 
 
 
665309d
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
import gradio as gr
import spaces
import torch
import torch.amp as amp

from transformers import MistralForCausalLM, LlamaTokenizer, pipeline

# Hugging Face repo of the Transformers-converted Cosmos 12B prompt upsampler
# (community conversion of nvidia/Cosmos-1.0-Prompt-Upsampler-12B-Text2World).
repo_id = "appmana/Cosmos-1.0-Prompt-Upsampler-12B-Text2World-hf"

# Load the causal LM in bf16 with FlashAttention-2 to fit/serve the 12B model
# efficiently on GPU. NOTE(review): flash_attention_2 requires the flash-attn
# package and a supported GPU — presumably guaranteed by the Space's hardware.
model = MistralForCausalLM.from_pretrained(
    repo_id,
    attn_implementation="flash_attention_2",
    torch_dtype=torch.bfloat16
)
tokenizer = LlamaTokenizer.from_pretrained(repo_id)
# Wrap model + tokenizer in a text-generation pipeline; this is the single
# global inference entry point used by upsample() below.
pipe = pipeline(
    "text-generation",
    model=model,
    tokenizer=tokenizer,
    torch_dtype=torch.bfloat16
)

@spaces.GPU
def upsample(prompt):
    """Expand a short caption into a long, detailed caption.

    Args:
        prompt: Short user-supplied caption (plain text).

    Returns:
        The model's generated long caption as a string — only the newly
        generated text, not the echoed chat template.
    """
    # Wrap the caption in the model's chat template; add_generation_prompt
    # appends the assistant-turn marker so generation starts at the reply.
    template = tokenizer.apply_chat_template(
        [{"role": "user", "content": f"Upsample the short caption to a long caption: {prompt}"}],
        tokenize=False,
        add_generation_prompt=True
    )
    response = pipe(
        template,
        min_p=0.01,
        top_p=0.95,
        top_k=40,
        do_sample=True,
        temperature=0.2,  # low temperature: near-deterministic expansions
        max_new_tokens=512,
        pad_token_id=tokenizer.eos_token_id,  # silences the missing-pad-token warning
        # BUG FIX: the pipeline defaults to return_full_text=True, which
        # prepends the whole chat template (instruction markers + user prompt)
        # to the output shown in the UI. Return only the generated completion.
        return_full_text=False
    )
    return response[0]["generated_text"]

# Minimal Gradio UI: one text box in (short caption), one read-only text box
# out (the upsampled caption), backed by upsample() above.
demo = gr.Interface(
    title="NVIDIA Cosmos 🌌 Prompt Upsampler",
    description="""Upsample prompts using NVIDIA's 12B Cosmos model, based on Mistral NeMo 12B. This space uses the HuggingFace Transformers version at bfloat16 precision.
    
    [[cosmos]](https://huggingface.co/nvidia/Cosmos-1.0-Prompt-Upsampler-12B-Text2World) [[transformers]](https://huggingface.co/appmana/Cosmos-1.0-Prompt-Upsampler-12B-Text2World-hf) [[gguf]](https://huggingface.co/mradermacher/Cosmos-1.0-Prompt-Upsampler-12B-Text2World-hf-GGUF)""",
    fn=upsample,
    inputs=gr.Text(
        label="Prompt",
        value="A dog playing with a ball."
    ),
    outputs=gr.Text(
        label="Upsampled Prompt",
        interactive=False  # output is display-only
    )
)
# Blocking call: starts the web server (entry point of the Space).
demo.launch()