app-dfsrfx-18 / app.py
AiCoderv2's picture
Update Gradio app with multiple files
f13f1c5 verified
import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
import os
import torch
# Model setup: Phi-2 serves as the coding/conversation model for this app.
# The Hugging Face access token is optional and read from the environment;
# it is None when HF_TOKEN is unset.
token = os.environ.get("HF_TOKEN")
model_name = "microsoft/phi-2"

# NOTE(review): trust_remote_code=True allows Python code shipped in the model
# repository to be executed locally — confirm this is intended for this repo.
tokenizer = AutoTokenizer.from_pretrained(
    model_name,
    token=token,
    trust_remote_code=True,
)

# Half-precision weights; device_map="auto" leaves device placement to the
# library rather than pinning the model to a specific device here.
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    token=token,
    trust_remote_code=True,
    torch_dtype=torch.float16,
    device_map="auto",
)
def chat(message, history, max_new_tokens=100):
    """Stream a sampled reply to ``message`` given the prior conversation.

    The conversation is flattened into a plain-text prompt of alternating
    ``User:`` / ``Assistant:`` lines, then tokens are sampled one at a time
    from the model's next-token distribution. After every sampled token the
    full reply decoded so far is yielded, so a consumer can display it
    progressively.

    Args:
        message: The new user message to answer.
        history: Prior turns as an iterable of ``(user_msg, bot_msg)`` pairs;
            empty/None entries in a pair are skipped.
            NOTE(review): assumes the pair-style history format — confirm
            against the Gradio version/``type`` setting used by the caller.
        max_new_tokens: Upper bound on the number of sampled tokens
            (defaults to 100, the previous hard-coded limit).

    Yields:
        The reply text generated so far after each token, followed by one
        final yield of the complete reply with surrounding whitespace
        stripped.
    """
    # Build conversation prompt
    turns = []
    for user_msg, bot_msg in history:
        if user_msg:
            turns.append(f"User: {user_msg}\n")
        if bot_msg:
            turns.append(f"Assistant: {bot_msg}\n")
    turns.append(f"User: {message}\nAssistant:")
    prompt = "".join(turns)

    # Keep the running sequence as a plain tensor of token ids. The previous
    # code overwrote the tokenizer output with a tensor and then unpacked it
    # with ``**`` on the next step, which raised TypeError after one token.
    input_ids = tokenizer(prompt, return_tensors="pt").to(model.device)["input_ids"]
    eos_token_id = tokenizer.eos_token_id  # may be None for some tokenizers

    generated_tokens = []
    for _ in range(max_new_tokens):  # Limit to prevent infinite generation
        # Disable autograd only around the forward pass. Holding the context
        # open across ``yield`` would leave grad mode switched off while the
        # generator is suspended.
        with torch.no_grad():
            outputs = model(
                input_ids=input_ids,
                # Single unpadded sequence, so every position is attended to.
                attention_mask=torch.ones_like(input_ids),
            )
        next_token_logits = outputs.logits[:, -1, :]
        # Sample in float32 so the probabilities are not built in half precision.
        probs = torch.softmax(next_token_logits.float(), dim=-1)
        next_token = torch.multinomial(probs, num_samples=1)
        token_id = next_token.item()

        # Stop as soon as the model emits its end-of-sequence token.
        if eos_token_id is not None and token_id == eos_token_id:
            break
        generated_tokens.append(token_id)

        # Yield partial response
        current_text = tokenizer.decode(generated_tokens, skip_special_tokens=True)
        yield current_text

        # Simple end-of-reply heuristic: stop once the text ends with a
        # newline or sentence punctuation and is longer than 10 characters.
        if current_text.endswith(('\n', '.', '!', '?')) and len(current_text) > 10:
            break

        # Append the sampled token so the next forward pass sees it.
        input_ids = torch.cat(
            [input_ids, next_token.to(input_ids.device)], dim=-1
        )

    # Final yield
    final_response = tokenizer.decode(generated_tokens, skip_special_tokens=True).strip()
    yield final_response
# Chat UI: `chat` is a generator, so each value it yields updates the reply
# shown in the interface.
_DESCRIPTION = (
    "Chat with a coding expert AI powered by Phi-2. "
    "It can help with programming questions and general conversations. "
    "<a href='https://huggingface.co/spaces/akhaliq/anycoder' target='_blank'>"
    "Built with anycoder</a>"
)

demo = gr.ChatInterface(
    fn=chat,
    title="Coding Expert AI Chatbot",
    description=_DESCRIPTION,
    theme=gr.themes.Soft(),
)

# Start the web server only when this file is executed as a script.
if __name__ == "__main__":
    demo.launch()