import os
import re
import torch
from threading import Thread
from typing import Iterator

from mongoengine import connect, Document, StringField, SequenceField
import gradio as gr
import spaces
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, TextIteratorStreamer
from peft import PeftModel
# Constants
MAX_MAX_NEW_TOKENS = 2048
DEFAULT_MAX_NEW_TOKENS = 930
MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "4096"))
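# The input-length cap can be overridden from the environment, e.g.:
#   MAX_INPUT_TOKEN_LENGTH=8192 python app.py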
# Description and License Texts
DESCRIPTION = """
# ✨Storytell AI🧑🏽‍💻
Welcome to the **Storytell AI** space, crafted with care by Ranam & George. Dive into the world of educational storytelling with our model, a fine-tuned 7-billion-parameter iteration of Llama 2 that generates educational stories to engage and educate. Enjoy a journey of discovery and creativity: your storytelling lesson begins here! You can prompt this model to explain any computer science concept. **Please check the examples below.**
"""

LICENSE = """
---
As a derivative work of [Llama-2-7b-hf](https://huggingface.co/meta-llama/Llama-2-7b-hf) by Meta,
this demo is governed by the original [license](https://huggingface.co/spaces/huggingface-projects/llama-2-7b-chat/blob/main/LICENSE.txt) and [acceptable use policy](https://huggingface.co/spaces/huggingface-projects/llama-2-7b-chat/blob/main/USE_POLICY.md).
"""
# GPU check and CPU warning
if not torch.cuda.is_available():
    DESCRIPTION += "\n<p>Running on CPU 🥶 This demo does not work on CPU.</p>"

if torch.cuda.is_available():
    # Storytelling model: 4-bit NF4-quantized Llama 2 base with a PEFT adapter on top
    model_id = "meta-llama/Llama-2-7b-hf"
    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_use_double_quant=False,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.bfloat16,
    )
    base_model = AutoModelForCausalLM.from_pretrained(model_id, device_map="auto", quantization_config=bnb_config)
    storytell_model = PeftModel.from_pretrained(base_model, "ranamhamoud/storytell")
    storytell_tokenizer = AutoTokenizer.from_pretrained(model_id)
    storytell_tokenizer.pad_token = storytell_tokenizer.eos_token

    # Editing model: stock Llama 2 chat model in fp16
    editing_model_id = "meta-llama/Llama-2-7b-chat-hf"
    editing_model = AutoModelForCausalLM.from_pretrained(editing_model_id, torch_dtype=torch.float16, device_map="auto")
    editing_tokenizer = AutoTokenizer.from_pretrained(editing_model_id)  # must match the chat model, not the base model
    editing_tokenizer.use_default_system_prompt = False
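# The quantized base weights stay frozen; only the PEFT adapter (ranamhamoud/storytell)
# is applied on top of them, which lets the storytelling model share one GPU with the
# fp16 chat model used for editing.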
# MongoDB Connection
PASSWORD = os.environ.get("MONGO_PASS")
connect(host=f"mongodb+srv://ranamhammoud11:{PASSWORD}@stories.zf5v52a.mongodb.net/")

# MongoDB Document
class Story(Document):
    message = StringField()
    content = StringField()
    story_id = SequenceField(primary_key=True)
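# SequenceField gives each Story an auto-incrementing integer primary key;
# mongoengine keeps the counter in a separate collection behind the scenes.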
# Utility function for prompts
def make_prompt(entry):
    return f"### Human: Don't repeat the assessments, limit to 500 words {entry} ### Assistant:"
    # Alternative prompt kept for reference:
    # f"TELL A STORY, RELATE TO COMPUTER SCIENCE, INCLUDE ASSESSMENTS. MAKE IT REALISTIC AND AROUND 800 WORDS, END THE STORY WITH 'THE END.': {entry}"
def process_text(text):
    # Strip any bracketed spans (e.g. leftover [INST] markers) from the model output
    text = re.sub(r'\[.*?\]', '', text, flags=re.DOTALL)
    return text
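# For example:
#   process_text("Once upon a time [INST] in a compiler [/INST]")
#   -> "Once upon a time  in a compiler "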
# Gradio Function
# Note: gr.ChatInterface calls fn(message, history, *additional_inputs), so the
# dropdown value (model_choice) must come after message and chat_history.
@spaces.GPU  # requests a GPU per call when running on ZeroGPU hardware
def generate(
    message: str,
    chat_history: list[tuple[str, str]],
    model_choice: str,
    max_new_tokens: int = DEFAULT_MAX_NEW_TOKENS,
    temperature: float = 0.6,
    top_p: float = 0.7,
    top_k: int = 20,
    repetition_penalty: float = 1.0,
) -> Iterator[str]:
    if model_choice == "Storytell":
        model = storytell_model
        tokenizer = storytell_tokenizer
        # The fine-tuned model expects the "### Human: ... ### Assistant:" format,
        # so only the current message is wrapped into a single prompt.
        enc = tokenizer(make_prompt(message), return_tensors="pt", padding=True, truncation=True)
        input_ids = enc.input_ids.to(model.device)
    else:
        model = editing_model
        tokenizer = editing_tokenizer
        # The chat model receives the whole conversation through its chat template.
        conversation = []
        for user, assistant in chat_history:
            conversation.extend([{"role": "user", "content": user}, {"role": "assistant", "content": assistant}])
        conversation.append({"role": "user", "content": message})
        input_ids = tokenizer.apply_chat_template(conversation, add_generation_prompt=True, return_tensors="pt").to(model.device)
    if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
        input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
        gr.Warning(f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.")
    streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
    generate_kwargs = dict(
        input_ids=input_ids,
        streamer=streamer,
        max_new_tokens=max_new_tokens,
        do_sample=True,
        top_p=top_p,
        top_k=top_k,
        temperature=temperature,
        num_beams=1,
        repetition_penalty=repetition_penalty,
    )
    # Run generation on a background thread; the streamer yields tokens as they arrive.
    t = Thread(target=model.generate, kwargs=generate_kwargs)
    t.start()

    outputs = []
    for text in streamer:
        processed_text = process_text(text)
        outputs.append(processed_text)
        output = "".join(outputs)
        yield output

    final_story = "".join(outputs)
    try:
        saved_story = Story(message=message, content=final_story).save()
        yield f"{final_story}\n\nStory saved with ID: {saved_story.story_id}"
    except Exception as e:
        yield f"Failed to save story: {str(e)}"
# Gradio Interface Setup
chat_interface = gr.ChatInterface(
    fn=generate,
    stop_btn=None,
    additional_inputs=[
        # Default to "Storytell" so model_choice is never empty
        gr.Dropdown(["Storytell", "HF Meta Llama 7b Chat"], label="Choose Model", value="Storytell"),
    ],
    examples=[
        ["Can you briefly explain what the Python programming language is?"],
        ["Could you please provide an explanation of the concept of recursion?"],
        ["Could you explain what a URL is?"],
    ],
    theme="shivi/calm_seafoam",
)
# Gradio Web Interface
with gr.Blocks(css="style.css", theme="shivi/calm_seafoam") as demo:
    gr.Markdown(DESCRIPTION)
    chat_interface.render()
    gr.Markdown(LICENSE)
# Main Execution
if __name__ == "__main__":
    demo.queue(max_size=20)
    demo.launch(share=True)
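# Running this Space requires two secrets: MONGO_PASS for the database connection
# above, and a Hugging Face token with access to the gated meta-llama checkpoints
# (e.g. via `huggingface-cli login`).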