Spaces:
Sleeping
Sleeping
| import torch | |
| import gradio as gr | |
| from peft import PeftModel | |
| from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer | |
| from threading import Thread | |
| from fastapi import FastAPI | |
| from fastapi.responses import StreamingResponse, RedirectResponse | |
| from pydantic import BaseModel | |
| import json | |
| from typing import List, Literal | |
| import os | |
| import uvicorn | |
# Hugging Face access token, read from the environment (None when unset).
HF_TOKEN = os.getenv("HF_TOKEN")
# Hub id of the base checkpoint that the LoRA adapters are applied to.
MODEL = "meta-llama/Llama-3.2-1B-Instruct"
# FastAPI application; the Gradio UI is mounted onto it near the end of the file.
app = FastAPI()
# base model and tokenizer
base_model = AutoModelForCausalLM.from_pretrained(
    MODEL,
    token=HF_TOKEN,
    dtype=torch.float32, #huggingface free tier only has cpu
    device_map="cpu",
    low_cpu_mem_usage=True
)
# Enable the KV cache on the model config.
base_model.config.use_cache = True
tokenizer = AutoTokenizer.from_pretrained(MODEL, token=HF_TOKEN)
# Fall back to the EOS token for padding when the tokenizer defines no pad token.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
# lora adapters
# Maps each supported language name to the local directory of its LoRA adapter.
# The keys are also the language choices shown in the UI and reported by the API.
adapter_paths = {
    "English": "./models/English",
    "Spanish": "./models/Spanish",
    "Korean": "./models/Korean"
}
# single PeftModel instance that switches adapters
peft_model = None
# names of the adapters that have already been loaded into peft_model
loaded_adapters = set()


def load_adapter(language):
    """Return the shared PeftModel prepared for ``language``.

    The first call wraps ``base_model`` in a PeftModel built from that
    language's adapter. Later calls load the adapter on demand (once per
    language) and then switch the shared model to it.

    Args:
        language: Key into ``adapter_paths``.

    Returns:
        The module-level ``peft_model``.

    Raises:
        KeyError: If an adapter has to be loaded and ``language`` is not a
            key of ``adapter_paths``.
    """
    global peft_model

    if peft_model is not None:
        # Model already exists: lazily load the adapter, then activate it.
        already_loaded = language in loaded_adapters
        if not already_loaded:
            peft_model.load_adapter(adapter_paths[language], adapter_name=language)
            loaded_adapters.add(language)
        peft_model.set_adapter(language)
        return peft_model

    # First request: create the PeftModel from the requested adapter.
    peft_model = PeftModel.from_pretrained(
        base_model,
        adapter_paths[language],
        adapter_name=language,
        is_trainable=False
    )
    peft_model.eval()
    loaded_adapters.add(language)
    return peft_model
# the input will be a list of messages that include system, user, and assistant prompts
def generate_text_stream(messages, language, max_length=256, temperature=0.7):
    """Stream a chat completion for ``messages`` using the ``language`` adapter.

    Generation runs in a background thread and pushes decoded text into a
    ``TextIteratorStreamer``; this generator yields each piece as it arrives.

    Args:
        messages: List of ``{"role": ..., "content": ...}`` dicts, passed to
            the tokenizer's chat template.
        language: Key of ``adapter_paths`` selecting the LoRA adapter.
        max_length: Maximum number of new tokens to generate.
        temperature: Sampling temperature passed to ``generate``.

    Yields:
        Text fragments of the reply. For an unsupported ``language`` a single
        ``"Error: ..."`` message is yielded instead.

    Raises:
        Exception: Any exception raised by ``model.generate`` in the worker
            thread is re-raised here after the stream has been drained.
    """
    if language not in adapter_paths:
        yield f"Error: Language '{language}' not supported. Choose from: {list(adapter_paths.keys())}"
        return

    model = load_adapter(language)
    inputs = tokenizer.apply_chat_template(
        messages,
        tokenize=True,
        add_generation_prompt=True,  # provides assistant: so that it can start generating
        return_tensors="pt",
        return_dict=True
    ).to(model.device)
    streamer = TextIteratorStreamer(
        tokenizer,
        skip_prompt=True,
        skip_special_tokens=True)
    generation_kwargs = {
        **inputs,  # the key-value pairs in inputs are applied to this new dictionary
        "max_new_tokens": max_length,
        "temperature": temperature,
        "do_sample": True,  # to stop greedy selection
        "pad_token_id": tokenizer.eos_token_id,
        "streamer": streamer,
        "num_beams": 1,  # keep only 1 sequence till the end
        "use_cache": True,  # KV caching
    }

    # Exceptions raised in the worker thread are collected here so they can
    # be re-raised in the consuming thread.
    worker_errors = []

    def _run_generation():
        # Run generate() and make sure a failure still terminates the stream.
        try:
            model.generate(**generation_kwargs)
        except Exception as exc:
            worker_errors.append(exc)
            # generate() failed before signalling end-of-stream; without this
            # the consumer loop below would block forever.
            streamer.end()

    thread = Thread(target=_run_generation)
    thread.start()
    for text in streamer:
        yield text
    thread.join()

    if worker_errors:
        raise worker_errors[0]
# using pydantic to ensure data schemas
class Message(BaseModel):
    """One chat message in a generation request."""

    # who produced the message; restricted to these three roles
    role: Literal["system", "user", "assistant"]
    # text of the message
    content: str
class GenerateRequest(BaseModel):
    """Request body for the streaming generation endpoint."""

    # conversation so far, in order
    messages: List[Message]
    # language name; checked against adapter_paths by generate_text_stream
    language: str
    # forwarded to generation as the maximum number of new tokens
    max_length: int = 256
    # sampling temperature forwarded to generation
    temperature: float = 0.7
# fastAPI endpoints
# return information about the API
# NOTE(review): this handler had no route decorator, so it was never served.
# "/api" is used because "/" is taken by the Gradio mount — confirm the path.
@app.get("/api")
def read_api():
    """Describe the API: supported languages, host device and endpoints.

    Returns:
        A JSON-serialisable dict with a title, the language names from
        ``adapter_paths``, a device description and the endpoint list.
    """
    return {
        "message": "Multi-language Chatbot API",
        "languages": list(adapter_paths.keys()),
        "device": "CPU 16GB in Huggingface Space",
        "endpoints": {
            "POST /api/generate": "Generate with streaming",
            "GET /api/languages": "List available languages"
        }
    }
# return information about the languages the model supports
# Registered under the "GET /api/languages" route advertised by read_api;
# without the decorator the handler was never served.
@app.get("/api/languages")
def get_languages():
    """List the languages for which a LoRA adapter is configured.

    Returns:
        ``{"languages": [...]}`` with the keys of ``adapter_paths``.
    """
    return {
        "languages": list(adapter_paths.keys()),
    }
# providing a response through a stream
# Registered under the "POST /api/generate" route advertised by read_api;
# without the decorator the handler was never served.
@app.post("/api/generate")
async def generate_stream_api(request: GenerateRequest):
    """Stream a generated reply as Server-Sent Events.

    Each text fragment is sent as ``data: {"token": ...}``. A successful
    stream ends with ``data: [DONE]``; if generation raises, a single
    ``data: {"error": ...}`` event is sent instead.

    Args:
        request: Validated request body (messages, language, sampling options).

    Returns:
        A ``StreamingResponse`` with media type ``text/event-stream``.
    """
    # because pydantic uses Message class
    # this needs to be converted again to plain dictionary
    messages_dicts = [{"role": msg.role, "content": msg.content} for msg in request.messages]

    def event_generator():
        try:
            for token in generate_text_stream(
                messages_dicts,
                request.language,
                request.max_length,
                request.temperature
            ):
                yield f"data: {json.dumps({'token': token})}\n\n"
            yield "data: [DONE]\n\n"
        except Exception as e:
            # Report the failure to the client in-band; the HTTP status has
            # already been sent by the time streaming starts.
            yield f"data: {json.dumps({'error': str(e)})}\n\n"

    # SSE is implemented
    return StreamingResponse(
        event_generator(),
        media_type="text/event-stream",  # SSE content type
        headers={
            "Cache-Control": "no-cache",  # Don't cache streaming responses
            "Connection": "keep-alive",  # Keep connection open
            "X-Accel-Buffering": "no",
        }
    )
def chat_gradio(message, history, language, system_prompt, max_length, temperature):
    """Gradio handler: stream the assistant's reply into the chat history.

    Args:
        message: Text the user just submitted.
        history: Existing conversation as a list of role/content dicts;
            ``None`` is treated as an empty conversation.
        language: Key of ``adapter_paths`` selecting the adapter.
        system_prompt: Optional system prompt; skipped when empty.
        max_length: Maximum number of new tokens to generate.
        temperature: Sampling temperature.

    Yields:
        The full history followed by the new user message and the assistant
        reply accumulated so far, once per streamed fragment.
    """
    # Guard against a missing history so slicing and concatenation below work.
    history = history or []

    messages = []
    if system_prompt:
        messages.append({"role": "system", "content": system_prompt})
    # only uses the last 10 messages to keep within context limit
    messages.extend(history[-10:])
    user_msg = {"role": "user", "content": message}
    messages.append(user_msg)

    assistant_msg = {"role": "assistant", "content": ""}
    for token in generate_text_stream(
        messages,
        language,
        max_length,
        temperature
    ):
        # Grow the reply in place and re-emit the whole conversation so the
        # chatbot shows the partial answer.
        assistant_msg["content"] += token
        yield history + [user_msg, assistant_msg]
# Gradio UI: conversation and input on the left, generation settings on the right.
with gr.Blocks(
    title="Language Learning Chatbot",
    theme=gr.themes.Soft()
) as demo:
    with gr.Row():
        with gr.Column(scale=2):
            chatbot = gr.Chatbot(
                label="Conversation",
                height=500,
                show_copy_button=True,  # Let users copy messages
                type="messages"
            )
            # User input
            with gr.Row():
                msg = gr.Textbox(
                    label="Your message",
                    placeholder="Type your message here and press Enter...",
                    lines=2,
                    scale=4
                )
            with gr.Row():
                submit_btn = gr.Button("Send", variant="primary", scale=1)
                clear_btn = gr.Button("Clear Chat", scale=1)
        with gr.Column(scale=1):
            gr.Markdown("### ⚙️ Settings")
            language_dropdown = gr.Dropdown(
                choices=list(adapter_paths.keys()),
                label="Language",
                value=list(adapter_paths.keys())[0],
                info="Select the language model to use"
            )
            system_prompt_input = gr.Textbox(
                label="System Prompt (Optional)",
                placeholder="e.g., You are a helpful assistant...",
                lines=3,
                info="Set the assistant's behavior"
            )
            max_length_slider = gr.Slider(
                minimum=50,
                maximum=512,
                value=256,
                step=1,
                label="Max Length (tokens)",
                info="Maximum tokens to generate"
            )
            temperature_slider = gr.Slider(
                minimum=0.1,
                maximum=1.0,
                value=0.7,
                step=0.05,
                label="Temperature",
                info="Higher = more creative"
            )

    # Pressing Enter in the textbox and clicking "Send" behave identically:
    # stream the reply into the chatbot, then clear the input box.
    _chat_inputs = [
        msg,
        chatbot,
        language_dropdown,
        system_prompt_input,
        max_length_slider,
        temperature_slider,
    ]
    for _trigger in (msg.submit, submit_btn.click):
        _trigger(
            fn=chat_gradio,
            inputs=_chat_inputs,
            outputs=[chatbot],  # Update chatbot with streaming response
        ).then(
            fn=lambda: gr.update(value=""),  # Clear input after sending
            outputs=[msg]
        )

    # Clear chat button
    clear_btn.click(
        fn=lambda: None,  # Return None to clear chatbot
        outputs=[chatbot],
        queue=False  # Don't queue this action
    )
# NOTE(review): False is passed positionally to demo.queue(); which queue
# option that maps to depends on the installed Gradio version — confirm this
# is the intended setting, since chat_gradio is a streaming generator handler.
demo.queue(False)
# Serve the Gradio UI from the root path of the FastAPI app.
app = gr.mount_gradio_app(app, demo, path="/")
if __name__ == "__main__":
    # Run the combined FastAPI + Gradio app on all interfaces, port 7860.
    uvicorn.run(app, host="0.0.0.0", port=7860)