import torch
import gradio as gr
from peft import PeftModel
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
from threading import Thread
from fastapi import FastAPI
from fastapi.responses import StreamingResponse, RedirectResponse
from pydantic import BaseModel
import json
from typing import List, Literal, Optional
import os
import uuid
import time

# Use more CPU threads for faster inference
torch.set_num_threads(4)

HF_TOKEN = os.getenv("HF_TOKEN")
MODEL = "meta-llama/Llama-3.2-1B-Instruct"

app = FastAPI()
# Base model for finetuned (LoRA) inference
finetuned_base = AutoModelForCausalLM.from_pretrained(
    MODEL,
    token=HF_TOKEN,
    dtype=torch.bfloat16,  # faster than float32, matches GPU training
    device_map="cpu",
    low_cpu_mem_usage=True,
    attn_implementation="sdpa",  # PyTorch optimized attention
)
finetuned_base.config.use_cache = True
# Separate base model for comparison (no LoRA)
base_model = AutoModelForCausalLM.from_pretrained(
    MODEL,
    token=HF_TOKEN,
    dtype=torch.bfloat16,
    device_map="cpu",
    low_cpu_mem_usage=True,
    attn_implementation="sdpa",
)
base_model.config.use_cache = True
base_model.eval()

tokenizer = AutoTokenizer.from_pretrained(MODEL, token=HF_TOKEN)
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
# LoRA adapters
adapter_paths = {
    "English": "./models/English",
    "Spanish": "./models/Spanish",
    "Korean": "./models/Korean"
}
languages = list(adapter_paths.keys())

# Create the PeftModel with the first adapter
peft_model = PeftModel.from_pretrained(
    finetuned_base,
    adapter_paths[languages[0]],
    adapter_name=languages[0],
    is_trainable=False
)
# Load the remaining adapters
for lang in languages[1:]:
    peft_model.load_adapter(adapter_paths[lang], adapter_name=lang)
peft_model.eval()
print("All adapters ready.")
# Base model generation (no LoRA)
def generate_base_model_stream(messages, max_length=256, temperature=0.7):
    print("Base model (no LoRA)")
    print(f"Messages: {messages}")
    inputs = tokenizer.apply_chat_template(
        messages,
        tokenize=True,
        add_generation_prompt=True,
        return_tensors="pt",
        return_dict=True
    ).to(base_model.device)
    streamer = TextIteratorStreamer(
        tokenizer,
        skip_prompt=True,
        skip_special_tokens=True)
    generation_kwargs = {
        **inputs,
        "max_new_tokens": max_length,
        "temperature": temperature,
        "do_sample": True,
        "pad_token_id": tokenizer.eos_token_id,
        "streamer": streamer,
        "num_beams": 1,
        "use_cache": True,
    }
    # Generate in a background thread so tokens can be yielded as the streamer produces them
    thread = Thread(target=base_model.generate, kwargs=generation_kwargs)
    thread.start()
    for text in streamer:
        yield text
    thread.join()
# The input is a list of messages that may include system, user, and assistant turns
def generate_text_stream(messages, language, max_length=256, temperature=0.7):
    print(f"Language adapter: {language}")
    print(f"Messages: {messages}")
    if language not in adapter_paths:
        yield f"Error: Language '{language}' not supported. Choose from: {list(adapter_paths.keys())}"
        return
    peft_model.set_adapter(language)
    inputs = tokenizer.apply_chat_template(
        messages,
        tokenize=True,
        add_generation_prompt=True,  # appends the assistant header so generation can start
        return_tensors="pt",
        return_dict=True
    ).to(peft_model.device)
    streamer = TextIteratorStreamer(
        tokenizer,
        skip_prompt=True,
        skip_special_tokens=True)
    generation_kwargs = {
        **inputs,  # unpack input_ids / attention_mask into this dictionary
        "max_new_tokens": max_length,
        "temperature": temperature,
        "do_sample": True,  # sample instead of greedy selection
        "pad_token_id": tokenizer.eos_token_id,
        "streamer": streamer,
        "num_beams": 1,  # keep only one sequence
        "use_cache": True,  # KV caching
    }
    thread = Thread(target=peft_model.generate, kwargs=generation_kwargs)
    thread.start()
    for text in streamer:
        yield text
    thread.join()
# Pydantic models to validate request schemas
class Message(BaseModel):
    role: Literal["system", "user", "assistant"]
    content: str

class GenerateRequest(BaseModel):
    messages: List[Message]
    language: str
    max_length: int = 256
    temperature: float = 0.7

# OpenAI-compatible request format for InferenceClient
class ChatCompletionRequest(BaseModel):
    model: str = "default"
    messages: List[Message]
    max_tokens: Optional[int] = 256
    temperature: Optional[float] = 0.7
    stream: Optional[bool] = True
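# Example payload accepted by POST /api/generate below (values are illustrative):
# {
#   "messages": [
#     {"role": "system", "content": "You are a helpful Spanish tutor."},
#     {"role": "user", "content": "How do I say 'good morning'?"}
#   ],
#   "language": "Spanish",
#   "max_length": 256,
#   "temperature": 0.7
# }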
# FastAPI endpoints
# Return information about the API
@app.get("/api")  # info route (path assumed)
def read_api():
    return {
        "message": "Multi-language Chatbot API",
        "languages": list(adapter_paths.keys()),
        "device": "CPU 16GB in Huggingface Space",
        "endpoints": {
            "POST /api/generate": "Generate with streaming",
            "GET /api/languages": "List available languages"
        }
    }

# List the languages the model supports
@app.get("/api/languages")
def get_languages():
    return {
        "languages": list(adapter_paths.keys()),
    }
# Stream a response as server-sent events (SSE)
@app.post("/api/generate")
async def generate_stream_api(request: GenerateRequest):
    # Pydantic parses the body into Message objects;
    # convert them back to plain dictionaries for the chat template
    messages_dicts = [{"role": msg.role, "content": msg.content} for msg in request.messages]

    def event_generator():
        try:
            for token in generate_text_stream(
                messages_dicts,
                request.language,
                request.max_length,
                request.temperature
            ):
                yield f"data: {json.dumps({'token': token})}\n\n"
            yield "data: [DONE]\n\n"
        except Exception as e:
            yield f"data: {json.dumps({'error': str(e)})}\n\n"

    # SSE response
    return StreamingResponse(
        event_generator(),
        media_type="text/event-stream",  # SSE content type
        headers={
            "Cache-Control": "no-cache",  # don't cache streaming responses
            "Connection": "keep-alive",  # keep the connection open
            "X-Accel-Buffering": "no",  # disable proxy buffering
        }
    )
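# Minimal client sketch for the endpoint above (assumes the `requests` library and
# a SPACE_URL variable pointing at this app; both are illustrative):
#
#   import requests, json
#   payload = {"messages": [{"role": "user", "content": "Hello"}], "language": "English"}
#   with requests.post(f"{SPACE_URL}/api/generate", json=payload, stream=True) as resp:
#       for line in resp.iter_lines():
#           if line.startswith(b"data: ") and line != b"data: [DONE]":
#               print(json.loads(line[len(b"data: "):]).get("token", ""), end="")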
# OpenAI-compatible endpoint for the HuggingFace InferenceClient.
# Pass the language via the `model` field (e.g., "English", "Spanish", "Korean").
@app.post("/v1/chat/completions")
async def chat_completions(request: ChatCompletionRequest):
    messages_dicts = [{"role": msg.role, "content": msg.content} for msg in request.messages]
    # Use the model field as the language selector; default to English if invalid
    language = request.model if request.model in adapter_paths else "English"
    chat_id = f"chatcmpl-{uuid.uuid4().hex[:8]}"
    created = int(time.time())

    def event_generator():
        try:
            for token in generate_text_stream(
                messages_dicts,
                language,
                request.max_tokens or 256,
                request.temperature or 0.7
            ):
                chunk = {
                    "id": chat_id,
                    "object": "chat.completion.chunk",
                    "created": created,
                    "model": language,
                    "choices": [{
                        "index": 0,
                        "delta": {"content": token},
                        "finish_reason": None
                    }]
                }
                yield f"data: {json.dumps(chunk)}\n\n"
            # Final chunk with finish_reason
            final_chunk = {
                "id": chat_id,
                "object": "chat.completion.chunk",
                "created": created,
                "model": language,
                "choices": [{
                    "index": 0,
                    "delta": {},
                    "finish_reason": "stop"
                }]
            }
            yield f"data: {json.dumps(final_chunk)}\n\n"
            yield "data: [DONE]\n\n"
        except Exception as e:
            error_chunk = {"error": {"message": str(e), "type": "server_error"}}
            yield f"data: {json.dumps(error_chunk)}\n\n"

    return StreamingResponse(
        event_generator(),
        media_type="text/event-stream",
        headers={
            "Cache-Control": "no-cache",
            "Connection": "keep-alive",
            "X-Accel-Buffering": "no",
        }
    )
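# Client sketch for the OpenAI-compatible endpoint above (assumes huggingface_hub is
# installed and SPACE_URL points at this app; `model` selects the language adapter):
#
#   from huggingface_hub import InferenceClient
#   client = InferenceClient(base_url=SPACE_URL)
#   stream = client.chat.completions.create(
#       model="Korean",
#       messages=[{"role": "user", "content": "Teach me a greeting."}],
#       stream=True,
#   )
#   for chunk in stream:
#       print(chunk.choices[0].delta.content or "", end="")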
# Gradio callback: stream a reply from the plain base model
def chat_base_model(message, history, system_prompt, max_length, temperature):
    messages = []
    if system_prompt:
        messages.append({"role": "system", "content": system_prompt})
    # Keep only the last 10 history messages to bound the prompt length
    messages.extend(history[-10:])
    user_msg = {"role": "user", "content": message}
    messages.append(user_msg)
    assistant_msg = {"role": "assistant", "content": ""}
    for token in generate_base_model_stream(
        messages,
        max_length,
        temperature
    ):
        assistant_msg["content"] += token
        yield history + [user_msg, assistant_msg]

# Gradio callback: stream a reply from the finetuned (LoRA) model
def chat_finetuned(message, history, language, system_prompt, max_length, temperature):
    messages = []
    if system_prompt:
        messages.append({"role": "system", "content": system_prompt})
    messages.extend(history[-10:])
    user_msg = {"role": "user", "content": message}
    messages.append(user_msg)
    assistant_msg = {"role": "assistant", "content": ""}
    for token in generate_text_stream(
        messages,
        language,
        max_length,
        temperature
    ):
        assistant_msg["content"] += token
        yield history + [user_msg, assistant_msg]
with gr.Blocks(
    title="Language Learning Chatbot",
    theme=gr.themes.Soft()
) as demo:
    with gr.Row():
        with gr.Column(scale=1):
            gr.Markdown("### Base Model (No LoRA)")
            chatbot_base = gr.Chatbot(
                label="Base Model",
                height=400,
                type="messages"
            )
        with gr.Column(scale=1):
            gr.Markdown("### Finetuned Model (LoRA)")
            chatbot_finetuned = gr.Chatbot(
                label="Finetuned Model",
                height=400,
                type="messages"
            )
    with gr.Row():
        msg = gr.Textbox(
            label="Your message",
            placeholder="Type your message here and press Enter...",
            lines=2,
            scale=4
        )
    with gr.Row():
        submit_btn = gr.Button("Send", variant="primary", scale=1)
        clear_btn = gr.Button("Clear Both Chats", scale=1)
    with gr.Row():
        with gr.Column():
            gr.Markdown("### Settings")
            language_dropdown = gr.Dropdown(
                choices=list(adapter_paths.keys()),
                label="Language (for Finetuned Model)",
                value=list(adapter_paths.keys())[0],
                info="Select the language adapter to use"
            )
            system_prompt_input = gr.Textbox(
                label="System Prompt (Optional)",
                placeholder="e.g., You are a helpful assistant...",
                lines=3,
                info="Shared between both models"
            )
            max_length_slider = gr.Slider(
                minimum=50,
                maximum=512,
                value=256,
                step=1,
                label="Max Length (tokens)",
                info="Maximum tokens to generate"
            )
            temperature_slider = gr.Slider(
                minimum=0.1,
                maximum=1.0,
                value=0.7,
                step=0.05,
                label="Temperature",
                info="Higher = more creative"
            )
    # Handle the Enter key in the textbox - send to both models
    msg.submit(
        fn=chat_base_model,
        inputs=[msg, chatbot_base, system_prompt_input, max_length_slider, temperature_slider],
        outputs=[chatbot_base],
    )
    msg.submit(
        fn=chat_finetuned,
        inputs=[msg, chatbot_finetuned, language_dropdown, system_prompt_input, max_length_slider, temperature_slider],
        outputs=[chatbot_finetuned],
    ).then(
        fn=lambda: gr.update(value=""),
        outputs=[msg]
    )
    # Handle button click - send to both models
    submit_btn.click(
        fn=chat_base_model,
        inputs=[msg, chatbot_base, system_prompt_input, max_length_slider, temperature_slider],
        outputs=[chatbot_base],
    )
    submit_btn.click(
        fn=chat_finetuned,
        inputs=[msg, chatbot_finetuned, language_dropdown, system_prompt_input, max_length_slider, temperature_slider],
        outputs=[chatbot_finetuned],
    ).then(
        fn=lambda: gr.update(value=""),
        outputs=[msg]
    )
    # Clear both chats
    clear_btn.click(
        fn=lambda: (None, None),
        outputs=[chatbot_base, chatbot_finetuned],
        queue=False
    )

demo.queue(False)
app = gr.mount_gradio_app(app, demo, path="/")
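# Local run sketch (assumes this file is saved as app.py and uvicorn is installed):
#   uvicorn app:app --host 0.0.0.0 --port 7860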