junchenfu commited on
Commit
159a3dd
·
verified ·
1 Parent(s): a44f788

Upload LLMPopcorn.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. LLMPopcorn.py +102 -0
LLMPopcorn.py ADDED
@@ -0,0 +1,102 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import os
import re
import torch
import random
import numpy as np
from tqdm import tqdm
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline, BitsAndBytesConfig

# Restrict CUDA to GPU 0 before any CUDA context is created;
# device_map="auto" below will then only see this one device.
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

# Set random seed
SEED = 42
torch.manual_seed(SEED)
np.random.seed(SEED)
random.seed(SEED)

# Input file and output directory
input_file = "abstract_prompts.txt"  # one user query per line
output_dir = "baseline_concrete_outputsf"  # NOTE(review): trailing 'f' looks like a typo — confirm intended name
os.makedirs(output_dir, exist_ok=True)

# Model name (example)
LLAMA_MODEL_NAME = "meta-llama/Llama-3.3-70B-Instruct"
tokenizer = AutoTokenizer.from_pretrained(LLAMA_MODEL_NAME)

# 4-bit (bitsandbytes) quantization so the 70B weights fit in GPU memory.
quantization_config = BitsAndBytesConfig(load_in_4bit=True)
model_llama = AutoModelForCausalLM.from_pretrained(
    LLAMA_MODEL_NAME,
    device_map="auto",  # let accelerate place layers on the visible GPU(s)
    torch_dtype=torch.bfloat16,  # compute dtype for non-quantized modules
    quantization_config=quantization_config
)

# Set up pipeline
llama_pipeline = pipeline(
    "text-generation",
    model=model_llama,
    tokenizer=tokenizer,
    max_new_tokens=5000,  # generous cap: title + cover prompt + video prompt in JSON
    temperature=0.7,
    top_p=0.9,
    repetition_penalty=1.1,
    do_sample=True  # stochastic decoding; seeds above make runs reproducible
)
43
+
44
# Helper: derive a filesystem-safe name from a free-form query string.
def sanitize_filename(filename: str) -> str:
    """Return *filename* rewritten so it is safe to use as a file name.

    Surrounding whitespace is dropped, characters that are illegal in
    file names on common platforms are replaced with underscores, and
    the result is capped at 100 characters.
    """
    cleaned = re.sub(r'[\\/*?:"<>|]', "_", filename.strip())
    # Truncate defensively; slicing is a no-op for already-short names.
    return cleaned[:100]
53
+
54
# Read every query (one per line) from the input file.
with open(input_file, "r", encoding="utf-8") as f:
    lines = f.readlines()

# Process each line
for line in tqdm(lines):
    query = line.strip()
    if not query:
        continue  # skip blank lines

    # Prepare the LLM input prompt: a system message fixing the persona and
    # a user message carrying the query plus the exact JSON answer template.
    messages = [
        {
            "role": "system",
            "content": (
                "Now that you're a talented video creator with a wealth of ideas, you need to think from the user's perspective and after that generate the most popular video title, "
                "an AI-generated cover prompt, and a 3-second AI-generated video prompt."
            )
        },
        {
            "role": "user",
            "content": (
                f"Below is the user query:\n\n{query}\n\n"
                "Final Answer Requirements:\n"
                "- A single line for the final generated Title (MAX_length = 50).\n"
                "- A single paragraph for the Cover Prompt.\n"
                "- A single paragraph for the Video Prompt (3-second).\n\n"
                "Now, based on the above reasoning, generate the response in JSON format. Here is an example:\n"
                "{\n"
                ' "title": "Unveiling the Legacy of Ancient Rome: Rise, Glory, and Downfall.",\n'
                ' "cover_prompt": "Generate an image of a Roman Emperor standing proudly in front of the Colosseum, with a subtle sunset backdrop, highlighting the contrast between the ancient structure.",\n'
                ' "video_prompt": "Open with a 3-second aerial shot of the Roman Forum, showcasing the sprawling ancient ruins against a clear blue sky, before zooming in on a singular, imposing structure like the Colosseum."\n'
                "}\n"
                "Please provide your answer following this exact JSON template for the response."
            )
        }
    ]

    # Call the LLM for inference
    response = llama_pipeline(messages, num_return_sequences=1)
    # In chat mode the pipeline returns the whole conversation: the input
    # messages with the generated assistant message appended at the end.
    final_output = response[0]["generated_text"]

    # Determine output file name and save.
    # NOTE(review): distinct queries that sanitize/truncate to the same
    # 100-char name will silently overwrite each other — confirm acceptable.
    output_filename = sanitize_filename(query) + ".txt"
    output_path = os.path.join(output_dir, output_filename)

    with open(output_path, "w", encoding="utf-8") as out_f:
        # Fix: index the LAST message rather than hard-coding position 2, so
        # this keeps working if the number of input messages ever changes.
        out_f.write(final_output[-1]["content"])

    print(f"Processed query: {query} -> {output_path}")